diff --git a/.ci/tests/examples/api_test.py b/.ci/tests/examples/api_test.py new file mode 100644 index 000000000..e9a5bd06d --- /dev/null +++ b/.ci/tests/examples/api_test.py @@ -0,0 +1,83 @@ +import fire +import yaml + +from fedn import APIClient + + +def _download_config(output): + """ Download the client configuration file from the controller. + + :param output: The output file path. + :type output: str + """ + client = APIClient(host="localhost", port=8092) + config = client.get_client_config(checksum=True) + with open(output, 'w') as f: + f.write(yaml.dump(config)) + + +def test_api_get_methods(): + client = APIClient(host="localhost", port=8092) + status = client.get_controller_status() + assert status + print("Controller status: ", status, flush=True) + + events = client.get_events() + assert events + print("Events: ", events, flush=True) + + validations = client.list_validations() + assert validations + print("Validations: ", validations, flush=True) + + models = client.get_model_trail() + assert models + print("Models: ", models, flush=True) + + clients = client.list_clients() + assert clients + print("Clients: ", clients, flush=True) + + combiners = client.list_combiners() + assert combiners + print("Combiners: ", combiners, flush=True) + + combiner = client.get_combiner("combiner") + assert combiner + print("Combiner: ", combiner, flush=True) + + first_model = client.get_initial_model() + assert first_model + print("First model: ", first_model, flush=True) + + package = client.get_package() + assert package + print("Package: ", package, flush=True) + + checksum = client.get_package_checksum() + assert checksum + print("Checksum: ", checksum, flush=True) + + rounds = client.list_rounds() + assert rounds + print("Rounds: ", rounds, flush=True) + + round = client.get_round(1) + assert round + print("Round: ", round, flush=True) + + sessions = client.list_sessions() + assert sessions + print("Sessions: ", sessions, flush=True) + + +if __name__ == '__main__': + + client = APIClient(host="localhost", port=8092) + fire.Fire({ + 'set_seed': client.set_initial_model, + 'set_package': client.set_package, + 'start_session': client.start_session, + 'get_client_config': _download_config, + 'test_api_get_methods': test_api_get_methods, + }) diff --git a/.ci/tests/examples/print_logs.sh b/.ci/tests/examples/print_logs.sh index 4c63f141e..6979000ed 100755 --- a/.ci/tests/examples/print_logs.sh +++ b/.ci/tests/examples/print_logs.sh @@ -5,8 +5,11 @@ docker logs "$(basename $PWD)_minio_1" echo "Mongo logs" docker logs "$(basename $PWD)_mongo_1" -echo "Reducer logs" -docker logs "$(basename $PWD)_reducer_1" +echo "Dashboard logs" +docker logs "$(basename $PWD)_dashboard_1" + +echo "API-Server logs" +docker logs "$(basename $PWD)_api-server_1" echo "Combiner logs" docker logs "$(basename $PWD)_combiner_1" diff --git a/.ci/tests/examples/run.sh b/.ci/tests/examples/run.sh index 7afe8b9cc..da1c7981f 100755 --- a/.ci/tests/examples/run.sh +++ b/.ci/tests/examples/run.sh @@ -23,34 +23,23 @@ docker-compose \ ".$example/bin/python" ../../.ci/tests/examples/wait_for.py combiners >&2 echo "Upload compute package" -curl -k -X POST \ - -F file=@package.tgz \ - -F helper="$helper" \ - http://localhost:8090/context -printf '\n' +".$example/bin/python" ../../.ci/tests/examples/api_test.py set_package --path package.tgz --helper "$helper" >&2 echo "Upload seed" -curl -k -X POST \ - -F seed=@seed.npz \ - http://localhost:8090/models -printf '\n' +".$example/bin/python" ../../.ci/tests/examples/api_test.py set_seed --path seed.npz >&2 echo "Wait for clients to connect" ".$example/bin/python" ../../.ci/tests/examples/wait_for.py clients ->&2 echo "Start round" -curl -k -X POST \ - -F rounds=3 \ - -F validate=True \ - http://localhost:8090/control -printf '\n' +>&2 echo "Start session" +".$example/bin/python" ../../.ci/tests/examples/api_test.py start_session --rounds 3 --helper "$helper" >&2 echo "Checking rounds success" ".$example/bin/python" ../../.ci/tests/examples/wait_for.py rounds >&2 echo "Test client connection with dowloaded settings" # Get config -curl -k http://localhost:8090/config/download > ../../client.yaml +".$example/bin/python" ../../.ci/tests/examples/api_test.py get_client_config --output ../../client.yaml # Redeploy clients with config docker-compose \ @@ -62,5 +51,8 @@ docker-compose \ >&2 echo "Wait for clients to reconnect" ".$example/bin/python" ../../.ci/tests/examples/wait_for.py clients +>&2 echo "Test API GET requests" +".$example/bin/python" ../../.ci/tests/examples/api_test.py test_api_get_methods + popd >&2 echo "Test completed successfully" \ No newline at end of file diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 7fa75506d..dc3345da0 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -37,21 +37,31 @@ def _test_rounds(n_rounds): return n == n_rounds -def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8090'): +def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092'): try: - resp = requests.get( - f'http://{reducer_host}:{reducer_port}/netgraph', verify=False) + + endpoint = "list_clients" if node_type == "client" else "list_combiners" + + response = requests.get( + f'http://{reducer_host}:{reducer_port}/{endpoint}', verify=False) + + if response.status_code == 200: + + data = json.loads(response.content) + + count = 0 + if node_type == "client": + arr = data.get('result') + count = sum(element.get('status') == "online" for element in arr) + else: + count = data.get('count') + + _eprint(f'Active {node_type}s: {count}.') + return count == n_nodes + except Exception as e: _eprint(f'Reques exception econuntered: {e}.') return False - if resp.status_code == 200: - gr = json.loads(resp.content) - n = sum(values.get('type') == node_type and values.get( - 'status') == 'active' for values in gr['nodes']) - _eprint(f'Active {node_type}s: {n}.') - return n == n_nodes - _eprint(f'Reducer returned {resp.status_code}.') - return False def rounds(n_rounds=3): diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index c1ec38548..3b0f615f6 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -18,6 +18,8 @@ jobs: --skip .venv --skip .mnist-keras --skip .mnist-pytorch + --skip fedn_pb2.py + --skip fedn_pb2_grpc.py - name: check Python formatting run: > @@ -25,12 +27,14 @@ jobs: --exclude .venv --exclude .mnist-keras --exclude .mnist-pytorch + --exclude fedn_pb2.py + --exclude fedn_pb2_grpc.py . - name: run Python linter run: > .venv/bin/flake8 . - --exclude ".venv,.mnist-keras,.mnist-pytorch,fedn_pb2.py" + --exclude ".venv,.mnist-keras,.mnist-pytorch,fedn_pb2.py,fedn_pb2_grpc.py" - name: check for floating imports run: > @@ -38,7 +42,8 @@ jobs: --exclude-dir='.venv' --exclude-dir='.mnist-pytorch' --exclude-dir='.mnist-keras' - --exclude-dir='docs' + --exclude-dir='docs' + --exclude='tests.py' '^[ \t]+(import|from) ' -I . # TODO: add linting/formatting for all file types \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0ccd43f49..8a86fd439 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,4 +6,9 @@ build: python: "3.9" sphinx: - configuration: docs/conf.py \ No newline at end of file + configuration: docs/conf.py + +python: + install: + - method: pip + path: ./fedn diff --git a/README.rst b/README.rst index 42503985a..2afb60ebc 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ To connect a client that uses the data partition 'data/clients/1/mnist.pt': -v $PWD/data/clients/1:/var/data \ -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ --network=fedn_default \ - ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch run client -in client.yaml --name client1 + ghcr.io/scaleoutsystems/fedn/fedn:master-mnist-pytorch run client -in client.yaml --name client1 You are now ready to start training the model at http://localhost:8090/control. diff --git a/config/settings-client.yaml.template b/config/settings-client.yaml.template index e4035f8d9..d7146af26 100644 --- a/config/settings-client.yaml.template +++ b/config/settings-client.yaml.template @@ -1,3 +1,3 @@ network_id: fedn-network -discover_host: reducer -discover_port: 8090 +discover_host: api-server +discover_port: 8092 diff --git a/config/settings-combiner.yaml.template b/config/settings-combiner.yaml.template index 68deff143..8cef6643a 100644 --- a/config/settings-combiner.yaml.template +++ b/config/settings-combiner.yaml.template @@ -1,6 +1,6 @@ network_id: fedn-network -discover_host: reducer -discover_port: 8090 +discover_host: api-server +discover_port: 8092 name: combiner host: combiner diff --git a/config/settings-reducer.yaml.template b/config/settings-reducer.yaml.template index 3289b656a..fd9352331 100644 --- a/config/settings-reducer.yaml.template +++ b/config/settings-reducer.yaml.template @@ -1,4 +1,8 @@ network_id: fedn-network +controller: + host: api-server + port: 8092 + debug: True statestore: type: MongoDB diff --git a/docker-compose.yaml b/docker-compose.yaml index 610d6e4c0..c8d3aff15 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -58,14 +58,13 @@ services: ports: - 8081:8081 - # Reducer - reducer: + dashboard: environment: - GET_HOSTS_FROM=dns - USER=test - PROJECT=project - FLASK_DEBUG=1 - - FLASK_ENV=development + - STATESTORE_CONFIG=/app/config/settings-reducer.yaml build: context: . args: @@ -75,10 +74,34 @@ services: - ${HOST_REPO_DIR:-.}/fedn:/app/fedn entrypoint: [ "sh", "-c" ] command: - - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run reducer -n reducer --init=config/settings-reducer.yaml" + - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run dashboard -n reducer --init=config/settings-reducer.yaml" ports: - 8090:8090 + api-server: + environment: + - GET_HOSTS_FROM=dns + - USER=test + - PROJECT=project + - FLASK_DEBUG=1 + - STATESTORE_CONFIG=/app/config/settings-reducer.yaml + - MODELSTORAGE_CONFIG=/app/config/settings-reducer.yaml + build: + context: . + args: + BASE_IMG: ${BASE_IMG:-python:3.9-slim} + working_dir: /app + volumes: + - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + depends_on: + - minio + - mongo + entrypoint: [ "sh", "-c" ] + command: + - "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/python fedn/fedn/network/api/server.py" + ports: + - 8092:8092 + # Combiner combiner: environment: diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..ab335dcea --- /dev/null +++ b/docs/README.md @@ -0,0 +1,4 @@ +FEDn is using sphinx with reStructuredText. + +sphinx-apidoc --ext-autodoc --module-first -o _source ../fedn/fedn ../*tests* ../*exceptions* ../*common* ../ ../fedn/fedn/network/api/server.py ../fedn/fedn/network/controller/controlbase.py +sphinx-build . _build \ No newline at end of file diff --git a/docs/architecture.rst b/docs/architecture.rst index 6e04d2af0..0a770c5b6 100644 --- a/docs/architecture.rst +++ b/docs/architecture.rst @@ -1,53 +1,63 @@ Architecture overview ===================== -Constructing a federated model with FEDn amounts to a) specifying the details of the client-side training code and data integrations, and b) deploying the reducer-combiner network. A FEDn network, as illustrated in the picture below, is made up of three main components: the *Reducer*, one or more *Combiners*, and a number of *Clients*. The combiner network forms the backbone of the FedML orchestration mechanism, while the Reducer provides discovery services and provides controls to coordinate training over the combiner network. By horizontally scaling the combiner network, one can meet the needs of a growing number of clients. +Constructing a federated model with FEDn amounts to a) specifying the details of the client-side training code and data integrations, and b) deploying the federated network. A FEDn network, as illustrated in the picture below, is made up of components into three different tiers: the *Controller* tier (3), one or more *Combiners* in second tier (2), and a number of *Clients* in tier (1). +The combiners forms the backbone of the federated ML orchestration mechanism, while the Controller tier provides discovery services and controls to coordinate training over the federated network. +By horizontally scaling the number of combiners, one can meet the needs of a growing number of clients. -.. image:: img/overview.png +.. image:: img/FEDn_network.png :alt: FEDn network :width: 100% :align: center -Main components ---------------- -Client -...... -A Client is a data node, holding private data and connecting to a Combiner to receive model update requests and model validation requests during training rounds. Importantly, clients do not require any open ingress ports. A client receives the code to be executed from the Reducer upon connecting to the network, and thus they only need to be configured prior to connection to read the local datasets during training and validation. Python3 client implementation is provided out of the box, and it is possible to write clients in a variety of languages to target different software and hardware requirements. -Combiner -........ +The clients: tier 1 +................... -A combiner is an actor whose main role is to orchestrate and aggregate model updates from a number of clients during a training round. When and how to trigger such orchestration rounds are specified in the overall *compute plan* laid out by the Reducer. Each combiner in the network runs an independent gRPC server, providing RPCs for interacting with the alliance subsystem it controls. Hence, the total number of clients that can be accommodated in a FEDn network is proportional to the number of active combiners in the FEDn network. Combiners can be deployed anywhere, e.g. in a cloud or on a fog node to provide aggregation services near the cloud edge. +A Client (gRPC client) is a data node, holding private data and connecting to a Combiner (gRPC server) to receive model update requests and model validation requests during training sessions. +Importantly, clients uses remote procedure calls (RPC) to ask for model updates tasks, thus the clients not require any open ingress ports! A client receives the code (called package or compute package) to be executed from the *Controller* +upon connecting to the network, and thus they only need to be configured prior to connection to read the local datasets during training and validation. The package is based on entry points in the client code, and can be customized to fit the needs of the user. +This allows for a high degree of flexibility in terms of what kind of training and validation tasks that can be performed on the client side. Such as different types of machine learning models and framework, and even programming languages. +A python3 client implementation is provided out of the box, and it is possible to write clients in a variety of languages to target different software and hardware requirements. -Reducer -....... +The combiners: tier 2 +..................... -The reducer fills three main roles in the FEDn network: 1.) it lays out the overall, global training strategy and communicates that to the combiner network. It also dictates the strategy to aggregate model updates from individual combiners into a single global model, 2.) it handles global state and maintains the *model trail* - an immutable trail of global model updates uniquely defining the FedML training timeline, and 3.) it provides discovery services, mediating connections between clients and combiners. For this purpose, the Reducer exposes a standard REST API. +A combiner is an actor whose main role is to orchestrate and aggregate model updates from a number of clients during a training session. +When and how to trigger such orchestration are specified in the overall *compute plan* laid out by the *Controller*. +Each combiner in the network runs an independent gRPC server, providing RPCs for interacting with the federated network it controls. +Hence, the total number of clients that can be accommodated in a FEDn network is proportional to the number of active combiners in the FEDn network. +Combiners can be deployed anywhere, e.g. in a cloud or on a fog node to provide aggregation services near the cloud edge. -Services and communication --------------------------- +The controller: tier 3 +...................... -The figure below provides a logical architecture view of the services provided by each agent and how they interact. +Tier 3 does actually contain several components and services, but we tend to associate it with the *Controller* the most. The *Controller* fills three main roles in the FEDn network: -.. image:: img/FEDn-arch-overview.png - :alt: FEDn architecture overview - :width: 100% - :align: center +1. it lays out the overall, global training strategy and communicates that to the combiner network. +It also dictates the strategy to aggregate model updates from individual combiners into a single global model, +2. it handles global state and maintains the *model trail* - an immutable trail of global model updates uniquely defining the federated ML training timeline, and +3. it provides discovery services, mediating connections between clients and combiners. For this purpose, the *Controller* exposes a standard REST API both for RPC clients and servers, but also for user interfaces and other services. + +Tier 3 also contain a *Reducer* component, which is responsible for aggregating combiner-level models into a single global model. Further, it contains a *StateStore* database, +which is responsible for storing various states of the network and training sessions. The final global model trail from a traning session is stored in the *ModelRegistry* database. -Control flows and algorithms ----------------------------- +Notes on aggregating algorithms +............................... -FEDn is designed to allow customization of the FedML algorithm, following a specified pattern, or programming model. Model aggregation happens on two levels in the system. First, each Combiner can be configured with a custom orchestration and aggregation implementation, that reduces model updates from Clients into a single, *combiner level* model. Then, a configurable aggregation protocol on the Reducer level is responsible for combining the combiner-level models into a global model. By varying the aggregation schemes on the two levels in the system, many different possible outcomes can be achieved. Good staring configurations are provided out-of-the-box to help the user get started. +FEDn is designed to allow customization of the FedML algorithm, following a specified pattern, or programming model. +Model aggregation happens on two levels in the network. First, each Combiner can be configured with a custom orchestration and aggregation implementation, that reduces model updates from Clients into a single, *combiner level* model. +Then, a configurable aggregation protocol on the *Controller* level is responsible for combining the combiner-level models into a global model. By varying the aggregation schemes on the two levels in the system, +many different possible outcomes can be achieved. Good starting configurations are provided out-of-the-box to help the user get started. See API reference for more details. Hierarchical Federated Averaging ................................ -The currently implemented default scheme uses a local SGD strategy on the Combiner level aggregation and a simple average of models on the reducer level. This results in a highly horizontally scalable FedAvg scheme. The strategy works well with most artificial neural network (ANNs) models, and can in general be applied to models where it is possible and makes sense to form mean values of model parameters (for example SVMs). Additional FedML training protocols, including support for various types of federated ensemble models, are in active development. +The currently implemented default scheme uses a local SGD strategy on the Combiner level aggregation and a simple average of models on the reducer level. +This results in a highly horizontally scalable FedAvg scheme. The strategy works well with most artificial neural network (ANNs) models, +and can in general be applied to models where it is possible and makes sense to form mean values of model parameters (for example SVMs). + -.. image:: img/HFedAvg.png - :alt: FEDn architecture overview - :width: 100% - :align: center diff --git a/docs/conf.py b/docs/conf.py index 8133e96fc..963080333 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -106,3 +106,5 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'https://docs.python.org/': None} + +pygments_style = 'sphinx' diff --git a/docs/deployment.rst b/docs/deployment.rst index ff8f3ea0e..974d98842 100644 --- a/docs/deployment.rst +++ b/docs/deployment.rst @@ -1,4 +1,4 @@ -Deployment +Distributed Deployment ====================== This guide serves as reference deployment for setting up a FEDn network consisting of: @@ -29,7 +29,7 @@ The reducer and clients need to be able to resolve the hostname for the combiner we show how this can be achieved if no external DNS resolution is available, by setting "extra host" in the Docker containers for the Reducer and client. Note that there are many other possible ways to achieve this, depending on your setup. 1. Deploy storage and database services (MinIO, MongoDB and MongoExpress) --------------------------------------------------------------------- +------------------------------------------------------------------------- First, deploy MinIO and Mongo services on one of the hosts. Edit the `docker-compose.yaml` file to change the default passwords and ports. diff --git a/docs/faq.rst b/docs/faq.rst index b3eab3c6b..948e53d57 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -71,7 +71,7 @@ Q: How can I configure the round validity policy: In the main control implementation https://github.com/scaleoutsystems/fedn/blob/master/fedn/fedn/clients/reducer/control.py you can modify or replace the wiwmethod "check_round_validity_policy". As we expand with more implementations of this policy, we plan to make it runtime configurable. Q: Can I start a client listening only to training requests or only on validation requests?: -------------------------------------------------- +-------------------------------------------------------------------------------------------- Yes! From FEDn 0.3.0 there is an option to toggle which message streams a client subscibes to. For example, to start a pure validation client: diff --git a/docs/fedn.network.api.rst b/docs/fedn.network.api.rst new file mode 100644 index 000000000..b14090da3 --- /dev/null +++ b/docs/fedn.network.api.rst @@ -0,0 +1,34 @@ +fedn.network.api package +======================== + +.. automodule:: fedn.network.api + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.api.client module +------------------------------ + +.. automodule:: fedn.network.api.client + :members: + :undoc-members: + :show-inheritance: + +fedn.network.api.interface module +--------------------------------- + +.. automodule:: fedn.network.api.interface + :members: + :undoc-members: + :show-inheritance: + +fedn.network.api.network module +------------------------------- + +.. automodule:: fedn.network.api.network + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.clients.rst b/docs/fedn.network.clients.rst new file mode 100644 index 000000000..81b070aa1 --- /dev/null +++ b/docs/fedn.network.clients.rst @@ -0,0 +1,42 @@ +fedn.network.clients package +============================ + +.. automodule:: fedn.network.clients + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.clients.client module +---------------------------------- + +.. automodule:: fedn.network.clients.client + :members: + :undoc-members: + :show-inheritance: + +fedn.network.clients.connect module +----------------------------------- + +.. automodule:: fedn.network.clients.connect + :members: + :undoc-members: + :show-inheritance: + +fedn.network.clients.package module +----------------------------------- + +.. automodule:: fedn.network.clients.package + :members: + :undoc-members: + :show-inheritance: + +fedn.network.clients.state module +--------------------------------- + +.. automodule:: fedn.network.clients.state + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.combiner.aggregators.rst b/docs/fedn.network.combiner.aggregators.rst new file mode 100644 index 000000000..a26abf1f4 --- /dev/null +++ b/docs/fedn.network.combiner.aggregators.rst @@ -0,0 +1,26 @@ +fedn.network.combiner.aggregators package +========================================= + +.. automodule:: fedn.network.combiner.aggregators + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.combiner.aggregators.aggregatorbase module +------------------------------------------------------- + +.. automodule:: fedn.network.combiner.aggregators.aggregatorbase + :members: + :undoc-members: + :show-inheritance: + +fedn.network.combiner.aggregators.fedavg module +----------------------------------------------- + +.. automodule:: fedn.network.combiner.aggregators.fedavg + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.combiner.rst b/docs/fedn.network.combiner.rst new file mode 100644 index 000000000..a894200f4 --- /dev/null +++ b/docs/fedn.network.combiner.rst @@ -0,0 +1,58 @@ +fedn.network.combiner package +============================= + +.. automodule:: fedn.network.combiner + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fedn.network.combiner.aggregators + +Submodules +---------- + +fedn.network.combiner.connect module +------------------------------------ + +.. automodule:: fedn.network.combiner.connect + :members: + :undoc-members: + :show-inheritance: + +fedn.network.combiner.interfaces module +--------------------------------------- + +.. automodule:: fedn.network.combiner.interfaces + :members: + :undoc-members: + :show-inheritance: + +fedn.network.combiner.modelservice module +----------------------------------------- + +.. automodule:: fedn.network.combiner.modelservice + :members: + :undoc-members: + :show-inheritance: + +fedn.network.combiner.round module +---------------------------------- + +.. automodule:: fedn.network.combiner.round + :members: + :undoc-members: + :show-inheritance: + +fedn.network.combiner.server module +----------------------------------- + +.. automodule:: fedn.network.combiner.server + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.controller.rst b/docs/fedn.network.controller.rst new file mode 100644 index 000000000..a0e995805 --- /dev/null +++ b/docs/fedn.network.controller.rst @@ -0,0 +1,18 @@ +fedn.network.controller package +=============================== + +.. automodule:: fedn.network.controller + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.controller.control module +-------------------------------------- + +.. automodule:: fedn.network.controller.control + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.dashboard.rst b/docs/fedn.network.dashboard.rst new file mode 100644 index 000000000..25ee3e8d8 --- /dev/null +++ b/docs/fedn.network.dashboard.rst @@ -0,0 +1,26 @@ +fedn.network.dashboard package +============================== + +.. automodule:: fedn.network.dashboard + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.dashboard.plots module +----------------------------------- + +.. automodule:: fedn.network.dashboard.plots + :members: + :undoc-members: + :show-inheritance: + +fedn.network.dashboard.restservice module +----------------------------------------- + +.. automodule:: fedn.network.dashboard.restservice + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.loadbalancer.rst b/docs/fedn.network.loadbalancer.rst new file mode 100644 index 000000000..7934f2228 --- /dev/null +++ b/docs/fedn.network.loadbalancer.rst @@ -0,0 +1,34 @@ +fedn.network.loadbalancer package +================================= + +.. automodule:: fedn.network.loadbalancer + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.loadbalancer.firstavailable module +----------------------------------------------- + +.. automodule:: fedn.network.loadbalancer.firstavailable + :members: + :undoc-members: + :show-inheritance: + +fedn.network.loadbalancer.leastpacked module +-------------------------------------------- + +.. automodule:: fedn.network.loadbalancer.leastpacked + :members: + :undoc-members: + :show-inheritance: + +fedn.network.loadbalancer.loadbalancerbase module +------------------------------------------------- + +.. automodule:: fedn.network.loadbalancer.loadbalancerbase + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.rst b/docs/fedn.network.rst new file mode 100644 index 000000000..2e0ccd753 --- /dev/null +++ b/docs/fedn.network.rst @@ -0,0 +1,48 @@ +fedn.network package +==================== + +.. automodule:: fedn.network + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fedn.network.api + fedn.network.clients + fedn.network.combiner + fedn.network.controller + fedn.network.dashboard + fedn.network.loadbalancer + fedn.network.statestore + +Submodules +---------- + +fedn.network.config module +-------------------------- + +.. automodule:: fedn.network.config + :members: + :undoc-members: + :show-inheritance: + +fedn.network.reducer module +--------------------------- + +.. automodule:: fedn.network.reducer + :members: + :undoc-members: + :show-inheritance: + +fedn.network.state module +------------------------- + +.. automodule:: fedn.network.state + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.network.statestore.rst b/docs/fedn.network.statestore.rst new file mode 100644 index 000000000..06d2d4607 --- /dev/null +++ b/docs/fedn.network.statestore.rst @@ -0,0 +1,26 @@ +fedn.network.statestore package +=============================== + +.. automodule:: fedn.network.statestore + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.network.statestore.mongostatestore module +---------------------------------------------- + +.. automodule:: fedn.network.statestore.mongostatestore + :members: + :undoc-members: + :show-inheritance: + +fedn.network.statestore.statestorebase module +--------------------------------------------- + +.. automodule:: fedn.network.statestore.statestorebase + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.rst b/docs/fedn.rst new file mode 100644 index 000000000..0ef5dce19 --- /dev/null +++ b/docs/fedn.rst @@ -0,0 +1,16 @@ +fedn (python package) +===================== + +.. automodule:: fedn + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fedn.network + fedn.utils diff --git a/docs/fedn.utils.plugins.rst b/docs/fedn.utils.plugins.rst new file mode 100644 index 000000000..adc4ee88a --- /dev/null +++ b/docs/fedn.utils.plugins.rst @@ -0,0 +1,42 @@ +fedn.utils.plugins package +========================== + +.. automodule:: fedn.utils.plugins + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +fedn.utils.plugins.helperbase module +------------------------------------ + +.. automodule:: fedn.utils.plugins.helperbase + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.plugins.kerashelper module +------------------------------------- + +.. automodule:: fedn.utils.plugins.kerashelper + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.plugins.numpyarrayhelper module +------------------------------------------ + +.. automodule:: fedn.utils.plugins.numpyarrayhelper + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.plugins.pytorchhelper module +--------------------------------------- + +.. automodule:: fedn.utils.plugins.pytorchhelper + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/fedn.utils.rst b/docs/fedn.utils.rst new file mode 100644 index 000000000..7fcc67d44 --- /dev/null +++ b/docs/fedn.utils.rst @@ -0,0 +1,58 @@ +fedn.utils package +================== + +.. automodule:: fedn.utils + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fedn.utils.plugins + +Submodules +---------- + +fedn.utils.checksum module +-------------------------- + +.. automodule:: fedn.utils.checksum + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.dispatcher module +---------------------------- + +.. automodule:: fedn.utils.dispatcher + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.helpers module +------------------------- + +.. automodule:: fedn.utils.helpers + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.logger module +------------------------ + +.. automodule:: fedn.utils.logger + :members: + :undoc-members: + :show-inheritance: + +fedn.utils.process module +------------------------- + +.. automodule:: fedn.utils.process + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/img/FEDn_network.png b/docs/img/FEDn_network.png new file mode 100644 index 000000000..76d1a53d2 Binary files /dev/null and b/docs/img/FEDn_network.png differ diff --git a/docs/index.rst b/docs/index.rst index b4bc2a16f..fe253738a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,12 +1,20 @@ -.. include:: ../README.rst - -Table of Contents ------------------ .. toctree:: :maxdepth: 2 :caption: Table of Contents + introduction + quickstart architecture deployment + interfaces tutorial - faq \ No newline at end of file + faq + modules + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` \ No newline at end of file diff --git a/docs/interfaces.rst b/docs/interfaces.rst new file mode 100644 index 000000000..f32020261 --- /dev/null +++ b/docs/interfaces.rst @@ -0,0 +1,37 @@ +User interfaces +=============== + +FEDn comes with an *APIClient* and a *Dashboard* for interacting with the FEDn network. The APIClient is a Python3 library that can be used to interact with the FEDn network programmatically. +The Dashboard is a web-based user interface that can be used to interact with the FEDn network through a web browser. + +APIClient +-------------- +The APIClient is a Python3 library that can be used to interact with the FEDn network programmatically. The APIClient is available as a Python package on PyPI, and can be installed using pip: + +.. code-block:: bash + + $ pip install fedn + + +To initialize the APIClient, you need to provide the hostname and port of the FEDn API server. The default port is 8092. The following code snippet shows how to initialize the APIClient: + +.. code-block:: python + + from fedn import APIClient + client = APIClient("localhost", 8092) + +For more information on how to use the APIClient, see the :py:mod:`fedn.network.api.client`. + +Dashboard +-------------- +The Dashboard is a web-based user interface that can be used to interact with the FEDn network through a web browser. The Dashboard is available as a Docker image, and can be run using the following command: + +.. code:: bash + + $ docker-compose up -d dashboard + +OBS! If you have followed any of the examples, the dashboard will already be running! +The Dashboard is now available at http://localhost:8090. If no compute package has been configured, the Dashboard will ask you to upload a compute package. +A compute package is a zip file containing the ML code that will be executed on the clients. +For more information on how to create a compute package, see the :ref:`compute package documentation `. After uploading a compute package, you will also need to upload an initial model. This initial model is +usually the initial weights for the model that will be trained. You can then navigate to the Control Panel to start a training session. diff --git a/docs/introduction.rst b/docs/introduction.rst new file mode 100644 index 000000000..6897690ba --- /dev/null +++ b/docs/introduction.rst @@ -0,0 +1,61 @@ +Introduction to Federated Learning +================================== + +Federated Learning stands at the forefront of modern machine learning techniques, offering a novel approach to address challenges related to data privacy, security, +and decentralized data distribution. In contrast to traditional machine learning setups where data is collected and stored centrally, +Federated Learning allows for collaborative model training while keeping data localized. This innovative paradigm proves to be particularly advantageous in +scenarios where data cannot be easily shared due to privacy regulations, network limitations, or ownership concerns. + +At its core, Federated Learning orchestrates model training across distributed devices or servers, referred to as clients or participants. +These participants could be diverse endpoints such as mobile devices, IoT gadgets, or remote servers. Rather than transmitting raw data to a central location, +each participant computes gradients locally based on its data. These gradients are then communicated to a central server, often called the aggregator or orchestrator. +The central server aggregates and combines the gradients from multiple participants to update a global model. +This iterative process allows the global model to improve without the need to share the raw data. + +FEDn: the SDK for scalable federated learning +--------------------------------------------- + +FEDn serves as a System Development Kit (SDK) tailored for scalable federated learning. +It is used to implement the core server side logic (including model aggregation) and the client side integrations. +It implements functionality to deploy and scale the server side in geographically distributed setups. +Developers and ML engineers can use FEDn to build custom federated learning systems and bespoke deployments. + + +One of the standout features of FEDn is its ability to deploy and scale the server-side in geographically distributed setups, +adapting to varying project needs and geographical considerations. + + +Scalable and Resilient +...................... + +FEDn exhibits scalability and resilience, thanks to its multi-tiered architecture. Multiple aggregation servers, known as combiners, +form a network to divide the workload, coordinating clients, and aggregating models. +This architecture allows for high performance in various settings, from thousands of clients in a cross-device environment to +large model updates in a cross-silo scenario. Crucially, FEDn has built-in recovery capabilities for all critical components, enhancing system reliability. + +ML-Framework Agnostic +..................... + +With FEDn, model updates are treated as black-box computations, meaning it can support any ML model type or framework. +This flexibility allows for out-of-the-box support for popular frameworks like Keras and PyTorch, making it a versatile tool for any machine learning project. + +Security +......... + +A key security feature of FEDn is its client protection capabilities, negating the need for clients to expose any ingress ports, +thus reducing potential security vulnerabilities. + +Event Tracking and Training progress +.................................... + +To ensure transparency and control over the learning process, +FEDn logs events in the federation and does real-time tracking of training progress. A flexible API lets the user define validation strategies locally on clients. +Data is logged as JSON to MongoDB, enabling users to create custom dashboards and visualizations easily. + +User Interfaces +............... + +FEDn offers a Flask-based Dashboard that allows users to monitor client model validations in real time. It also facilitates tracking client training time distributions +and key performance metrics for clients and combiners, providing a comprehensive view of the system’s operation and performance. + +FEDn also comes with an REST-API for integration with external dashboards and visualization tools, or integration with other systems. \ No newline at end of file diff --git a/docs/modules.rst b/docs/modules.rst new file mode 100644 index 000000000..c4dfb74d1 --- /dev/null +++ b/docs/modules.rst @@ -0,0 +1,7 @@ +API reference +============= + +.. toctree:: + :maxdepth: 4 + + fedn diff --git a/docs/quickstart.rst b/docs/quickstart.rst new file mode 100644 index 000000000..2b89ff165 --- /dev/null +++ b/docs/quickstart.rst @@ -0,0 +1,118 @@ +Quick Start +=========== + +Clone this repository, locate into it and start a pseudo-distributed FEDn network using docker-compose: + +.. code-block:: + + docker-compose up + + + +This will start up all neccecary components for a FEDn network, execept for the clients. + +.. warning:: + The FEDn network is configured to use a local Minio and MongoDB instances for storage. This is not suitable for production, but is fine for testing. + +.. note:: + You have the option to programmatically interact with the FEDn network using the Python APIClient, or you can use the Dashboard. In these Note sections we will use the APIClient. + Install the FEDn via pip: + + .. code-block:: bash + + $ pip install fedn + # or from source + $ cd fedn + $ pip install . + +Navigate to http://localhost:8090. You should see the FEDn Dashboard, asking you to upload a compute package. The compute package is a tarball of a project. +The project in turn implements the entrypoints used by clients to compute model updates and to validate a model. + +Locate into 'examples/mnist-pytorch'. + +Start by initializing a virtual enviroment with all of the required dependencies for this project. + +.. code-block:: python + + bin/init_venv.sh + +Now create the compute package and an initial model: + +.. code-block:: + + bin/build.sh + +Upload the generated files 'package.tgz' and 'seed.npz' in the FEDn Dashboard. + +.. note:: + Instead of uploading in the dashboard do: + + .. code:: python + + >>> from fedn import APIClient + >>> client = APIClient(host="localhost", port=8092) + >>> client.set_package("package.tgz", helper="pytorchhelper") + >>> client.set_initial_model("seed.npz") + +The next step is to configure and attach clients. For this we need to download data and make data partitions: + +Download the data: + +.. code-block:: + + bin/get_data + + +Split the data in 2 parts for the clients: + +.. code-block:: + + bin/split_data + +Data partitions will be generated in the folder 'data/clients'. + +Now navigate to http://localhost:8090/network and download the client config file. Place it in the example working directory. + +.. note:: + In the python enviroment you installed FEDn: + + .. code:: python + + >>> import yaml + >>> config = client.get_client_config(checksum=True) + >>> with open("client.yaml", "w") as f: + >>> f.write(yaml.dump(config)) + +To connect a client that uses the data partition 'data/clients/1/mnist.pt': + +.. code-block:: + + docker run \ + -v $PWD/client.yaml:/app/client.yaml \ + -v $PWD/data/clients/1:/var/data \ + -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ + --network=fedn_default \ + ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch run client -in client.yaml --name client1 + +.. note:: + If you are using the APIClient you must also start the training client via "docker run" command as above. + +You are now ready to start training the model at http://localhost:8090/control. + +.. note:: + In the python enviroment you installed FEDn you can start training via: + + .. code:: python + + >>> ... + >>> client.start_session(session_id="test-session", rounds=3) + # Wait for training to complete, when controller is idle: + >>> client.get_controller_status() + # Show model trail: + >>> client.get_model_trail() + # Show model performance: + >>> client.list_validations() + + Please see :py:mod:`fedn.network.api` for more details on the APIClient. + +To scale up the experiment, refer to the README at 'examples/mnist-pytorch' (or the corresponding Keras version), where we explain how to use docker-compose to automate deployment of several clients. diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 8e2bb7f78..355d86919 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -1,9 +1,9 @@ Tutorial: Compute Package ================================================ -This tutorial walks you through the key step done by the *model initiator* when setting up a federated project. -The code for this guideline has been taken from the mnist-keras example provided `here `_. -However, some modification to the code has been made for educational purposes. +This tutorial walks you through the design of a *compute package* for a FEDn client. The compute package is a tar.gz bundle of the code to be executed by each data-provider/client. +You will learn how to design the compute package and how to write the entry points for training and validation. Examples are provided for the Keras and PyTorch frameworks, which can be +found in the `examples `_. The compute package ----------------------------- @@ -14,23 +14,28 @@ The compute package :align: center The *compute package* is a tar.gz bundle of the code to be executed by each data-provider/client. -This package is uploaded to the Reducer upon initialization of the FEDN Network (along with the initial model). +This package is uploaded to the *Controller* upon initialization of the FEDN Network (along with the initial model). When a client connects to the network, it downloads and unpacks the package locally and are then ready to participate in training and/or validation. -The logic is illustrated in the above figure. When the `FEDn client `_. +The logic is illustrated in the above figure. When the :py:mod:`fedn.network.clients` recieves a model update request from the combiner, it calls upon a Dispatcher that looks up entry point definitions in the compute package. These entrypoints define commands executed by the client to update/train or validate a model. Designing the compute package ------------------------------ -We recommend to use the project structure followed by most example projects `here `_. -In the examples we have roughly the following structure: - -.. image:: img/tree_package.png - :alt: Project structure - :scale: 70 - :align: center +We recommend to use the project structure followed by most example `projects `_. +In the examples we have roughly the following file and folder structure: + +| project +| ├── client +| │ ├── entrypoint.py +| │ └── fedn.yaml +| ├── data +| │ └── mnist.npz +| ├── requirements.txt +| └── docker-compose.yml/Dockerfile +| The "client" folder is the *compute package* which will become a tar.gz bundle of the code to be executed by each data-provider/client. The entry points, mentioned above, are defined in the *fedn.yaml*: @@ -39,217 +44,266 @@ each data-provider/client. The entry points, mentioned above, are defined in the entry_points: train: - command: python train.py + command: python entrypoint.py validate: - command: python validate.py - -Where the training entry point has the following logical overview: - -.. image:: img/TrainSISO.png - :alt: Training entrypoint - :width: 100% - :align: center + command: python entrypoint.py The training entry point should be a single-input single-output program, taking as input a model update file -and writing a model update file (same file format). Staging and upload of these files are handled by the FEDn client. A helper class in the FEDn SDK handled the ML-framework +and writing a model update file (same file format). Staging and upload of these files are handled by the FEDn client. A helper class in the FEDn SDK handles the ML-framework specific file serialization and deserialization. The validation entry point acts very similar except we perform validation on the *model_in* and outputs a json containing a validation scores (see more below). -Upon training (model update) requests from the combiner, the client will download the latest (current) global model and *train.py* will be executed with this model update as input. After training / updating completes, the local client will capture the output file and send back the updated model to the combiner. For the local execution this means that the program (in this case train.py) will be executed as: +Upon training (model update) requests from the combiner, the client will download the latest (current) global model and *entrypoint.py train* will be executed with this model update as input. +After training / updating completes, the local client will capture the output file and send back the updated model to the combiner. +For the local execution this means that the program (in this case entrypoint.py) will be executed as: .. code-block:: python - python train.py model_in model_out + python entrypoint.py train in_model_path out_model_path -A typical *train.py* example can look like this: +A *entrypoint.py* example can look like this: .. code-block:: python - from __future__ import print_function - import sys - import yaml + import collections + import math + import os + + import docker + import fire + import torch + + from fedn.utils.helpers import get_helper, save_metadata, save_metrics - from data.read_data import read_data + HELPER_MODULE = 'pytorchhelper' + NUM_CLASSES = 10 + def _compile_model(): + """ Compile the pytorch model. - def train(model,data_path,settings): + :return: The compiled model. + :rtype: torch.nn.Module """ - Training function which will be called upon model update requests - from the combiner - - :param model: The latest global model, see '__main__' - :type model: User defined - :param data: Traning data - :type data: User defined - :param settings: Hyper-parameters settings - :type settings: dict - :return: Trained/updated model - :rtype: User defined + class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = torch.nn.Linear(784, 64) + self.fc2 = torch.nn.Linear(64, 32) + self.fc3 = torch.nn.Linear(32, 10) + + def forward(self, x): + x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784))) + x = torch.nn.functional.dropout(x, p=0.5, training=self.training) + x = torch.nn.functional.relu(self.fc2(x)) + x = torch.nn.functional.log_softmax(self.fc3(x), dim=1) + return x + + # Return model + return Net() + + + def _load_data(data_path, is_train=True): + """ Load data from disk. + + :param data_path: Path to data file. + :type data_path: str + :param is_train: Whether to load training or test data. + :type is_train: bool + :return: Tuple of data and labels. + :rtype: tuple """ - print("-- RUNNING TRAINING --", flush=True) + if data_path is None: + data = torch.load(_get_data_path()) + else: + data = torch.load(data_path) - #CODE TO READ DATA - - #EXAMPLE, SOME USER DEFINED FUNCION THAT READS THE TRAINING DATA - (x_train, y_train) = read_data(data_path, trainset=True) + if is_train: + X = data['x_train'] + y = data['y_train'] + else: + X = data['x_test'] + y = data['y_test'] - #CODE FOR START TRAINING - #EXAMPLE (Tensoflow) - model.fit(x_train, y_train, batch_size=settings['batch_size'], epochs=settings['epochs'], verbose=1) + # Normalize + X = X / 255 - print("-- TRAINING COMPLETED --", flush=True) - return model + return X, y - if __name__ == '__main__': - - #READ HYPER_PARAMETER SETTINGS FROM YAML FILE - with open('settings.yaml', 'r') as fh: - try: - settings = dict(yaml.safe_load(fh)) - except yaml.YAMLError as e: - raise(e) - - #CREATE THE SEED MODEL AND UPDATE WITH LATEST WEIGHTS - #EXAMPLE, USE KERAS HELPER IN FEDN SDK FOR READING WEIGHTS - from fedn.utils.kerashelper import KerasHelper - helper = KerasHelper() - weights = helper.load_model(sys.argv[1]) - - #EXAMPLE, A USER DEFINED FUNCTION THAT CONSTRUCTS THE MODEL, E.G THE ARCHITECTURE OF A NEURAL NETWORK - from models.model import create_seed_model - model = create_seed_model() - #EXAMPLE (HOW TO SET WEIGHTS OF A MODEL DIFFERS BETWEEN LIBRARIES) - model.set_weights(weights) - - #CALL TRAINING FUNCTION AND GET UPDATED MODEL - model = train(model,'../data/your_data.file',settings) - - #SAVE/SEND MODEL - #EXAMPLE, USING KERAS HELPER IN FEDN SDK - helper.save_model(model.get_weights(),sys.argv[2]) - + def _save_model(model, out_path): + """ Save model to disk. -The format of the input and output files (model updates) are dependent on the ML framework used. A `helper class `_. -defines serializaion and de-serialization of the model updates. -Observe that the functions `create_seed_model `_ -and `read_data `_ are implemented by the user, where the first function -constructs (compiles) and returns an untrained (seed) model. We then take this model and set the weights to be equal to the current global model recieved -from the commbiner. In the example above we use the Keras helper class to de-serialize those weights and the keras funcion *model.set_weights()* to set the seed model to be equal to the current model. -We then call the *train* function to first read the training data -(obs. the location of the data can differ depending on if you run the client in a native or containerized environment, in the latter case it's recommend to mount the data to the container, -the location should then be relative to the mount path) and then start the training. -In this example, training equals fitting the keras model, thus we call *model.fit()* fucntion. -The *settings.yaml* is for conveniance and is not required but contains the hyper-parameter settings for the local training as key/value pairs. + :param model: The model to save. + :type model: torch.nn.Module + :param out_path: The path to save to. + :type out_path: str + """ + weights = model.state_dict() + weights_np = collections.OrderedDict() + for w in weights: + weights_np[w] = weights[w].cpu().detach().numpy() + helper = get_helper(HELPER_MODULE) + helper.save(weights, out_path) -For validations it is a requirement that the output is valid json: -.. code-block:: python + def _load_model(model_path): + """ Load model from disk. - python validate.py model_in validation.json - -The Dahboard in the FEDn UI will plot any scalar metric in this json file, but you can include any type in the file assuming that it is valid json. These values can then be obtained (by an athorized user) from the MongoDB database (via Mongo Express, or any query interface or API). Typically, the actual model is defined in a small library, and does not depend on FEDn. An example (based on the keras case) of the *validate.py* is povided below: + param model_path: The path to load from. + :type model_path: str + :return: The loaded model. + :rtype: torch.nn.Module + """ + helper = get_helper(HELPER_MODULE) + weights_np = helper.load(model_path) + weights = collections.OrderedDict() + for w in weights_np: + weights[w] = torch.tensor(weights_np[w]) + model = _compile_model() + model.load_state_dict(weights) + model.eval() + return model -.. code-block:: python - import sys - from data.read_data import read_data - import json - from sklearn import metrics - import os - import yaml - import numpy as np + def init_seed(out_path='seed.npz'): + """ Initialize seed model. - def validate(model,data): + :param out_path: The path to save the seed model to. + :type out_path: str """ - Validation function which will be called upon model validation requests - from the combiner. - - :param model: The latest global model, see '__main__' - :type model: User defined - :param data: The data used for validation, could include both training and test/validation data - :type data: User defined - :return: Model scores from the validation - :rtype: dict + # Init and save + model = _compile_model() + _save_model(model, out_path) + + + def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01): + """ Train model. + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_model_path: The path to save the output model to. + :type out_model_path: str + :param data_path: The path to the data file. + :type data_path: str + :param batch_size: The batch size to use. + :type batch_size: int + :param epochs: The number of epochs to train. + :type epochs: int + :param lr: The learning rate to use. + :type lr: float """ - print("-- RUNNING VALIDATION --", flush=True) - - #CODE TO READ DATA - - #EXAMPLE - (x_train, y_train) = read_data(data, trainset=True) - - #EXAMPLE - (x_test, y_test) = read_data(data, trainset=False) - - try: - #CODE HERE FOR OBTAINING VALIDATION SCORES - - #EXAMPLE - model_score = model.evaluate(x_train, y_train, verbose=0) - model_score_test = model.evaluate(x_test, y_test, verbose=0) - y_pred = model.predict(x_test) - y_pred = np.argmax(y_pred, axis=1) - clf_report = metrics.classification_report(y_test.argmax(axis=-1),y_pred) - - except Exception as e: - print("failed to validate the model {}".format(e),flush=True) - raise - - #PUT SCORES AS VALUES FOR CORRESPONDING KEYS (CHANGE VARIABLES): + # Load data + x_train, y_train = _load_data(data_path) + + # Load model + model = _load_model(in_model_path) + + # Train + optimizer = torch.optim.SGD(model.parameters(), lr=lr) + n_batches = int(math.ceil(len(x_train) / batch_size)) + criterion = torch.nn.NLLLoss() + for e in range(epochs): # epoch loop + for b in range(n_batches): # batch loop + # Retrieve current batch + batch_x = x_train[b * batch_size:(b + 1) * batch_size] + batch_y = y_train[b * batch_size:(b + 1) * batch_size] + # Train on batch + optimizer.zero_grad() + outputs = model(batch_x) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + # Log + if b % 100 == 0: + print( + f"Epoch {e}/{epochs-1} | Batch: {b}/{n_batches-1} | Loss: {loss.item()}") + + # Metadata needed for aggregation server side + metadata = { + 'num_examples': len(x_train), + 'batch_size': batch_size, + 'epochs': epochs, + 'lr': lr + } + + # Save JSON metadata file + save_metadata(metadata, out_model_path) + + # Save model update + _save_model(model, out_model_path) + + + def validate(in_model_path, out_json_path, data_path=None): + """ Validate model. + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_json_path: The path to save the output JSON to. + :type out_json_path: str + :param data_path: The path to the data file. + :type data_path: str + """ + # Load data + x_train, y_train = _load_data(data_path) + x_test, y_test = _load_data(data_path, is_train=False) + + # Load model + model = _load_model(in_model_path) + + # Evaluate + criterion = torch.nn.NLLLoss() + with torch.no_grad(): + train_out = model(x_train) + training_loss = criterion(train_out, y_train) + training_accuracy = torch.sum(torch.argmax( + train_out, dim=1) == y_train) / len(train_out) + test_out = model(x_test) + test_loss = criterion(test_out, y_test) + test_accuracy = torch.sum(torch.argmax( + test_out, dim=1) == y_test) / len(test_out) + + # JSON schema report = { - "classification_report": clf_report, - "training_loss": model_score[0], - "training_accuracy": model_score[1], - "test_loss": model_score_test[0], - "test_accuracy": model_score_test[1], - } + "training_loss": training_loss.item(), + "training_accuracy": training_accuracy.item(), + "test_loss": test_loss.item(), + "test_accuracy": test_accuracy.item(), + } - print("-- VALIDATION COMPLETE! --", flush=True) - return report + # Save JSON + save_metrics(report, out_json_path) - if __name__ == '__main__': - #READS THE LATEST WEIGHTS FROM GLOBAL MODEL (COMBINER) - - from fedn.utils.kerashelper import KerasHelper - helper = KerasHelper() - weights = helper.load_model(sys.argv[1]) - - #CREATE THE SEED MODEL AND UPDATE WITH LATEST WEIGHTS - from models.model import create_seed_model - model = create_seed_model() - #EXAMPLE (HOW TO SET WEIGHTS OF A MODEL DIFFERS BETWEEN LIBRARIES) - model.set_weights(weights) + if __name__ == '__main__': + fire.Fire({ + 'init_seed': init_seed, + 'train': train, + 'validate': validate, + # '_get_data_path': _get_data_path, # for testing + }) - #START VALIDATION - report = validate(model,'../data/your_data.file') - #SAVE/SEND SCORE REPORT - with open(sys.argv[2],"w") as fh: - fh.write(json.dumps(report)) -As demonstrated in the code above, the structure is very similar to *train.py*. The main difference is that we perform validation of a current model provided by the combiner instead of training. Again, the *read_data* function is defined by the user. Once, we have optained a validation -*report* as a dictionary we can dump as json (required). Observe that the key/values are arbitrary. +The format of the input and output files (model updates) are dependent on the ML framework used. A helper instance :py:mod:`fedn.utils.plugins.pytorchhelper` is used to handle the serialization and deserialization of the model updates. +The first function (_compile_model) is used to define the model architecture and creates an initial model (which is then used by _init_seed). The second function (_load_data) is used to read the data (train and test) from disk. +The third function (_save_model) is used to save the model to disk using the pytorch helper module :py:mod:`fedn.utils.plugins.pytorchhelper`. The fourth function (_load_model) is used to load the model from disk, again +using the pytorch helper module. The fifth function (_init_seed) is used to initialize the seed model. The sixth function (_train) is used to train the model, observe the two first arguments which will be set by the FEDn client. +The seventh function (_validate) is used to validate the model, again observe the two first arguments which will be set by the FEDn client. -For the initialization of the Reducer, both the compute package and an initial model (weights) are required as individual files. To obtain the initial weights file we can use the fedn helpers to save the seed model to an output file (*init_model.py*): - -.. code-block:: python +Finally, we use the python package fire to create a command line interface for the entry points. This is not required but convenient. - from fedn.utils.kerashelper import KerasHelper - from models.mnist_model import create_seed_model +For validations it is a requirement that the output is saved in a valid json format: - if __name__ == '__main__': +.. code-block:: python - #CREATE INITIAL MODEL, UPLOAD TO REDUCER - model = create_seed_model() - outfile_name = "../initial_model/initial_model.npz" + python entrypoint.py validate in_model_path out_json_path + +In the code example we use the helper function :py:meth:`fedn.utils.helpers.save_metrics` to save the validation scores as a json file. - weights = model.get_weights() - helper = KerasHelper() - helper.save_model(weights, outfile_name) +The Dahboard in the FEDn UI will plot any scalar metric in this json file, but you can include any type in the file assuming that it is valid json. These values can then be obtained (by an athorized user) from the MongoDB database or using the :py:mod:`fedn.network.api.client`. -Which will be saved into the *initial_model* folder for convenience. Of course this file can also be a pretrained seed model. +Packaging for distribution +-------------------------- For the compute package we need to compress the *client* folder as .tar.gz file. E.g. using: .. code-block:: bash @@ -257,25 +311,39 @@ For the compute package we need to compress the *client* folder as .tar.gz file. tar -czvf package.tar.gz client -More on Data access -------------------- +This file can then be uploaded to the FEDn network using the FEDn UI or the :py:mod:`fedn.network.api.client`. + + +More on local data access +------------------------- There are many possible ways to interact with the local dataset. In principle, the only requirement is that the train and validate endpoints are able to correctly -read and use the data. In practice, it is then necessary to make some assumption on the local environemnt when writing train.py and validate.py. This is best explained -by looking at the code above. Here we assume that the dataset is present in a file called "your_data.file" in a folder "data" one level up in the file hierarchy relative to -the exection of train.py. Then, independent on the preferred way to run the client (native, Docker, K8s etc) this structure needs to be maintained for this particular +read and use the data. In practice, it is then necessary to make some assumption on the local environemnt when writing entrypoint.py. This is best explained +by looking at the code above. Here we assume that the dataset is present in a file called "mnist.npz" in a folder "data" one level up in the file hierarchy relative to +the exection of entrypoint.py. Then, independent on the preferred way to run the client (native, Docker, K8s etc) this structure needs to be maintained for this particular compute package. Note however, that there are many ways to accompish this on a local operational level. Running the client ------------------ -We recommend you to test your code before running the client. For example, you can simply test *train.py* and *validate.py* by: +We recommend you to test your code before running the client. For example, you can simply test *train* and *validate* by: -.. code-block:: python +.. code-block:: bash + + python entrypoint.py train ../seed.npz ../model_update.npz --data_path ../data/mnist.npz + python entrypoint.py validate ../model_update.npz ../validation.json --data_path ../data/mnist.npz - python train.py ../initial_model/initial_model.npz +Once everything works as expected you can start the federated network, upload the tar.gz compute package and the initial model. +Finally connect a client to the network: + +.. code-block:: bash -Once everything works as expected you can start the Reducer, upload the tar.gz compute package and the initial weights, followed by starting one or many combiners. -Finally connect a client to the network. Instructions for how to connect clients can be found in the `examples `_. + docker run \ + -v $PWD/client.yaml:/app/client.yaml \ + -v $PWD/data/clients/1:/var/data \ + -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ + --network=fedn_default \ + ghcr.io/scaleoutsystems/fedn/fedn:master-mnist-pytorch run client -in client.yaml --name client1 +The container image "ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch" is a pre-built image with the FEDn client and the PyTorch framework installed. diff --git a/examples/mnist-pytorch/client/entrypoint b/examples/mnist-pytorch/client/entrypoint index 5b671f4b9..8d7953b59 100755 --- a/examples/mnist-pytorch/client/entrypoint +++ b/examples/mnist-pytorch/client/entrypoint @@ -25,7 +25,11 @@ def _get_data_path(): def _compile_model(): - # Define model + """ Compile the pytorch model. + + :return: The compiled model. + :rtype: torch.nn.Module + """ class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() @@ -45,7 +49,15 @@ def _compile_model(): def _load_data(data_path, is_train=True): - # Load data + """ Load data from disk. + + :param data_path: Path to data file. + :type data_path: str + :param is_train: Whether to load training or test data. + :type is_train: bool + :return: Tuple of data and labels. + :rtype: tuple + """ if data_path is None: data = torch.load(_get_data_path()) else: @@ -65,6 +77,13 @@ def _load_data(data_path, is_train=True): def _save_model(model, out_path): + """ Save model to disk. + + :param model: The model to save. + :type model: torch.nn.Module + :param out_path: The path to save to. + :type out_path: str + """ weights = model.state_dict() weights_np = collections.OrderedDict() for w in weights: @@ -74,6 +93,13 @@ def _save_model(model, out_path): def _load_model(model_path): + """ Load model from disk. + + param model_path: The path to load from. + :type model_path: str + :return: The loaded model. + :rtype: torch.nn.Module + """ helper = get_helper(HELPER_MODULE) weights_np = helper.load(model_path) weights = collections.OrderedDict() @@ -86,12 +112,32 @@ def _load_model(model_path): def init_seed(out_path='seed.npz'): + """ Initialize seed model. + + :param out_path: The path to save the seed model to. + :type out_path: str + """ # Init and save model = _compile_model() _save_model(model, out_path) def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01): + """ Train model. + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_model_path: The path to save the output model to. + :type out_model_path: str + :param data_path: The path to the data file. + :type data_path: str + :param batch_size: The batch size to use. + :type batch_size: int + :param epochs: The number of epochs to train. + :type epochs: int + :param lr: The learning rate to use. + :type lr: float + """ # Load data x_train, y_train = _load_data(data_path) @@ -134,6 +180,15 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1 def validate(in_model_path, out_json_path, data_path=None): + """ Validate model. + + :param in_model_path: The path to the input model. + :type in_model_path: str + :param out_json_path: The path to save the output JSON to. + :type out_json_path: str + :param data_path: The path to the data file. + :type data_path: str + """ # Load data x_train, y_train = _load_data(data_path) x_test, y_test = _load_data(data_path, is_train=False) diff --git a/fedn/__init__.py b/fedn/__init__.py index 52ce8c9c3..31be09d81 100644 --- a/fedn/__init__.py +++ b/fedn/__init__.py @@ -1,3 +1,3 @@ -# -# Scaleout Systems AB -# __author__ = 'Morgan Ekmefjord morgan@scaleout.se' +"""The fedn package.""" + +# flake8: noqa diff --git a/fedn/cli/__init__.py b/fedn/cli/__init__.py index 13c9b1c51..840d4252b 100644 --- a/fedn/cli/__init__.py +++ b/fedn/cli/__init__.py @@ -1,3 +1,2 @@ -from .control_cmd import control_cmd # noqa: F401 from .main import main # noqa: F401 from .run_cmd import run_cmd # noqa: F401 diff --git a/fedn/cli/control_cmd.py b/fedn/cli/control_cmd.py deleted file mode 100644 index 9305c7015..000000000 --- a/fedn/cli/control_cmd.py +++ /dev/null @@ -1,125 +0,0 @@ -import os -from datetime import datetime - -import click - -from fedn.common.control.package import Package, PackageRuntime - -from .main import main - - -@main.group('control') -@click.pass_context -def control_cmd(ctx): - """ - - :param ctx: - """ - # if daemon: - # print('{} NYI should run as daemon...'.format(__file__)) - pass - - -@control_cmd.command('package') -@click.option('-r', '--reducer', required=False) -@click.option('-p', '--port', required=False) -@click.option('-t', '--token', required=False) -@click.option('-n', '--name', required=False, default=None) -@click.option('-u', '--upload', required=False, default=None) -@click.option('-v', '--validate', required=False, default=False) -@click.option('-d', '--cwd', required=False, default=None) -@click.pass_context -def package_cmd(ctx, reducer, port, token, name, upload, validate, cwd): - """ - - :param ctx: - :param reducer: - :param port: - :param token: - :param name: - :param upload: - :param validate: - :param cwd: - """ - if not cwd: - cwd = os.getcwd() - - print("CONTROL: Bundling {} dir for distribution. Please wait for operation to complete..".format(cwd)) - - if not name: - name = str(os.path.basename(cwd)) + '-' + \ - datetime.today().strftime('%Y-%m-%d-%H%M%S') - - config = {'host': reducer, 'port': port, 'token': token, 'name': name, - 'cwd': cwd} - - package = Package(config) - - print("CONTROL: Bundling package..") - package.package(validate=validate) - print("CONTROL: Bundle completed\nCONTROL: Resulted in: {}.tar.gz".format(name)) - if upload: - print("CONTROL: started upload") - package.upload() - print("CONTROL: upload finished!") - else: - print("CONTROL: set --upload flag along with --reducer and --port if you want to upload directly from client.") - - -@control_cmd.command('unpack') -@click.option('-r', '--reducer', required=True) -@click.option('-p', '--port', required=True) -@click.option('-t', '--token', required=True) -@click.option('-n', '--name', required=False, default=None) -@click.option('-d', '--download', required=False, default=None) -@click.option('-v', '--validate', required=False, default=False) -@click.option('-c', '--cwd', required=False, default=None) -@click.pass_context -def unpack_cmd(ctx, reducer, port, token, name, download, validate, cwd): - """ - - :param ctx: - :param reducer: - :param port: - :param token: - :param name: - :param download: - :param validate: - :param cwd: - """ - if not cwd: - cwd = os.getcwd() - - # config = {'host': reducer, 'port': port, 'token': token, 'name': name, - # 'cwd': cwd} - - package = PackageRuntime(cwd, os.path.join(cwd, 'client')) - package.download(reducer, port, token) - package.unpack() - - -@control_cmd.command('template') -@click.pass_context -def template_cmd(ctx): - """ - - :param ctx: - """ - print("TODO: generate template") - pass - - -@control_cmd.command('start') -@click.option('-r', '--reducer', required=True) -@click.option('-p', '--port', required=True) -@click.option('-t', '--token', required=True) -@click.pass_context -def control_cmd(ctx, reducer, port, token): - """ - - :param ctx: - :param reducer: - :param port: - :param token: - """ - pass diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index bf49cd1f0..119b8de45 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -145,7 +145,7 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa client.run() -@run_cmd.command('reducer') +@run_cmd.command('dashboard') @click.option('-h', '--host', required=False) @click.option('-p', '--port', required=False, default='8090', show_default=True) @click.option('-k', '--secret-key', required=False, help='Set secret key to enable jwt token authentication.') @@ -154,15 +154,16 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa @click.option('-in', '--init', required=True, default=None, help='Set to a filename to (re)init reducer state from file.') @click.pass_context -def reducer_cmd(ctx, host, port, secret_key, local_package, name, init): - """ - - :param ctx: - :param discoverhost: - :param discoverport: - :param secret_key: - :param name: - :param init: +def dashboard_cmd(ctx, host, port, secret_key, local_package, name, init): + """ Start the dashboard service. + + :param ctx: Click context. + :param discoverhost: Hostname for discovery services (dashboard). + :param discoverport: Port for discovery services (dashboard). + :param secret_key: Set secret key to enable jwt token authentication. + :param local_package: Enable use of local compute package. + :param name: Set service name. + :param init: Set to a filename to (re)init config state from file. """ remote = False if local_package else True config = {'host': host, 'port': port, 'secret_key': secret_key, @@ -189,7 +190,7 @@ def reducer_cmd(ctx, host, port, secret_key, local_package, name, init): statestore_config = fedn_config['statestore'] if statestore_config['type'] == 'MongoDB': statestore = MongoStateStore( - network_id, statestore_config['mongo_config'], defaults=config['init']) + network_id, statestore_config['mongo_config'], fedn_config['storage']) else: print("Unsupported statestore type, exiting. ", flush=True) exit(-1) diff --git a/fedn/fedn/__init__.py b/fedn/fedn/__init__.py index f04a9cd80..7e3df239e 100644 --- a/fedn/fedn/__init__.py +++ b/fedn/fedn/__init__.py @@ -2,6 +2,11 @@ import os from os.path import basename, dirname, isfile +from fedn.network.api.client import APIClient + +# flake8: noqa + + modules = glob.glob(dirname(__file__) + "/*.py") __all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] diff --git a/fedn/fedn/common/config.py b/fedn/fedn/common/config.py new file mode 100644 index 000000000..f6c827d0d --- /dev/null +++ b/fedn/fedn/common/config.py @@ -0,0 +1,94 @@ +import os + +import yaml + +global STATESTORE_CONFIG +global MODELSTORAGE_CONFIG + + +def get_environment_config(): + """ Get the configuration from environment variables. + """ + global STATESTORE_CONFIG + global MODELSTORAGE_CONFIG + + STATESTORE_CONFIG = os.environ.get('STATESTORE_CONFIG', + '/workspaces/fedn/config/settings-reducer.yaml.template') + MODELSTORAGE_CONFIG = os.environ.get('MODELSTORAGE_CONFIG', + '/workspaces/fedn/config/settings-reducer.yaml.template') + + +def get_statestore_config(file=None): + """ Get the statestore configuration from file. + + :param file: The statestore configuration file (yaml) path (optional). + :type file: str + :return: The statestore configuration as a dict. + :rtype: dict + """ + if file is None: + get_environment_config() + file = STATESTORE_CONFIG + with open(file, 'r') as config_file: + try: + settings = dict(yaml.safe_load(config_file)) + except yaml.YAMLError as e: + raise (e) + return settings["statestore"] + + +def get_modelstorage_config(file=None): + """ Get the model storage configuration from file. + + :param file: The model storage configuration file (yaml) path (optional). + :type file: str + :return: The model storage configuration as a dict. + :rtype: dict + """ + if file is None: + get_environment_config() + file = MODELSTORAGE_CONFIG + with open(file, 'r') as config_file: + try: + settings = dict(yaml.safe_load(config_file)) + except yaml.YAMLError as e: + raise (e) + return settings["storage"] + + +def get_network_config(file=None): + """ Get the network configuration from file. + + :param file: The network configuration file (yaml) path (optional). + :type file: str + :return: The network id. + :rtype: str + """ + if file is None: + get_environment_config() + file = STATESTORE_CONFIG + with open(file, 'r') as config_file: + try: + settings = dict(yaml.safe_load(config_file)) + except yaml.YAMLError as e: + raise (e) + return settings["network_id"] + + +def get_controller_config(file=None): + """ Get the controller configuration from file. + + :param file: The controller configuration file (yaml) path (optional). + :type file: str + :return: The controller configuration as a dict. + :rtype: dict + """ + if file is None: + get_environment_config() + file = STATESTORE_CONFIG + with open(file, 'r') as config_file: + try: + settings = dict(yaml.safe_load(config_file)) + except yaml.YAMLError as e: + raise (e) + return settings["controller"] diff --git a/fedn/fedn/common/control/__init__.py b/fedn/fedn/common/control/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/fedn/fedn/common/net/connect.py b/fedn/fedn/common/net/connect.py deleted file mode 100644 index 0a7cf4051..000000000 --- a/fedn/fedn/common/net/connect.py +++ /dev/null @@ -1,175 +0,0 @@ -import enum - -import requests as r - - -class State(enum.Enum): - Disconnected = 0 - Connected = 1 - Error = 2 - - -class Status(enum.Enum): - Unassigned = 0 - Assigned = 1 - TryAgain = 2 - UnAuthorized = 3 - UnMatchedConfig = 4 - - -class ConnectorClient: - """ - Connector for assigning client to a combiner in the FEDn network. - """ - - def __init__(self, host, port, token, name, remote_package, force_ssl=False, verify=False, combiner=None, id=None): - - self.host = host - self.port = port - self.token = token - self.name = name - self.verify = verify - self.preferred_combiner = combiner - self.id = id - self.package = 'remote' if remote_package else 'local' - - # for https we assume a an ingress handles permanent redirect (308) - if force_ssl: - self.prefix = "https://" - else: - self.prefix = "http://" - if self.port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) - else: - self.connect_string = "{}{}".format( - self.prefix, self.host) - - print("\n\nsetting the connection string to {}\n\n".format( - self.connect_string), flush=True) - - def state(self): - """ - - :return: Connector State - """ - return self.state - - def assign(self): - """ - Connect client to FEDn network discovery service, ask for combiner assignment. - - :return: Tuple with assingment status, combiner connection information - if sucessful, else None. - :rtype: Status, json - """ - try: - retval = None - if self.preferred_combiner: - retval = r.get("{}?name={}&combiner={}".format(self.connect_string + '/assign', self.name, - self.preferred_combiner), - verify=self.verify, - allow_redirects=True, - headers={'Authorization': 'Token {}'.format(self.token)}) - else: - retval = r.get("{}?name={}".format(self.connect_string + '/assign', self.name), - verify=self.verify, - allow_redirects=True, - headers={'Authorization': 'Token {}'.format(self.token)}) - except Exception as e: - print('***** {}'.format(e), flush=True) - return Status.Unassigned, {} - - if retval.status_code == 401: - reason = "Unauthorized connection to reducer, make sure the correct token is set" - return Status.UnAuthorized, reason - - reducer_package = retval.json()['package'] - if reducer_package != self.package: - reason = "Unmatched config of compute package between client and reducer.\n" +\ - "Reducer uses {} package and client uses {}.".format( - reducer_package, self.package) - return Status.UnMatchedConfig, reason - - if retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - if 'msg' in retval.json(): - reason = retval.json()['msg'] - else: - reason = "Reducer was not ready. Try again later." - - return Status.TryAgain, reason - - return Status.Assigned, retval.json() - - return Status.Unassigned, None - - -class ConnectorCombiner: - """ - Connector for annnouncing combiner to the FEDn network. - """ - - def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, verify=False): - - self.host = host - self.fqdn = fqdn - self.port = port - self.myhost = myhost - self.myport = myport - self.token = token - self.name = name - self.secure = secure - self.verify = verify - - # for https we assume a an ingress handles permanent redirect (308) - self.prefix = "http://" - if port: - self.connect_string = "{}{}:{}".format( - self.prefix, self.host, self.port) - else: - self.connect_string = "{}{}".format( - self.prefix, self.host) - - print("\n\nsetting the connection string to {}\n\n".format( - self.connect_string), flush=True) - - def state(self): - """ - - :return: Combiner State - """ - return self.state - - def announce(self): - """ - Announce combiner to FEDn network via discovery service. - - :return: Tuple with announcement Status, FEDn network configuration - if sucessful, else None. - :rtype: Staus, json - """ - try: - retval = r.get("{}?name={}&address={}&fqdn={}&port={}&secure={}".format( - self.connect_string + '/add', - self.name, - self.myhost, - self.fqdn, - self.myport, - self.secure), - verify=self.verify, - headers={'Authorization': 'Token {}'.format(self.token)}) - except Exception: - return Status.Unassigned, {} - - if retval.status_code == 401: - reason = "Unauthorized connection to reducer, make sure the correct token is set" - return Status.UnAuthorized, reason - - if retval.status_code >= 200 and retval.status_code < 204: - if retval.json()['status'] == 'retry': - reason = "Reducer was not ready. Try again later." - return Status.TryAgain, reason - return Status.Assigned, retval.json() - - return Status.Unassigned, None diff --git a/fedn/fedn/common/net/grpc/fedn.proto b/fedn/fedn/common/net/grpc/fedn.proto index dca66fe20..ff0ee293c 100644 --- a/fedn/fedn/common/net/grpc/fedn.proto +++ b/fedn/fedn/common/net/grpc/fedn.proto @@ -4,7 +4,6 @@ package grpc; message Response { Client sender = 1; - //string client = 1; string response = 2; } @@ -19,7 +18,6 @@ enum StatusType { message Status { Client sender = 1; - //string client = 1; string status = 2; enum LogLevel { @@ -95,6 +93,7 @@ enum ModelStatus { IN_PROGRESS_OK = 2; FAILED = 3; } + message ModelRequest { Client sender = 1; Client receiver = 2; @@ -204,7 +203,8 @@ message ReportResponse { service Control { rpc Start(ControlRequest) returns (ControlResponse); rpc Stop(ControlRequest) returns (ControlResponse); - rpc Configure(ControlRequest) returns (ReportResponse); + rpc Configure(ControlRequest) returns (ReportResponse); + rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse); rpc Report(ControlRequest) returns (ReportResponse); } diff --git a/fedn/fedn/common/net/grpc/fedn_pb2.py b/fedn/fedn/common/net/grpc/fedn_pb2.py index 19cd119c6..fa4fbb16d 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2.py @@ -2,19 +2,20 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: fedn/common/net/grpc/fedn.proto """Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import enum_type_wrapper - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x66\x65\x64n/common/net/grpc/fedn.proto\x12\x04grpc\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\x8c\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.grpc.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.grpc.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xab\x01\n\x12ModelUpdateRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xaf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xc5\x01\n\x16ModelValidationRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x14\n\x0cis_inference\x18\x08 \x01(\x08\"\xa8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.grpc.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.grpc.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"R\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\r.grpc.Channel\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.grpc.Client\"0\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.grpc.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.grpc.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"R\n\x0eReportResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.grpc.ConnectionStatus*\x84\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\r\n\tINFERENCE\x10\x05*\x86\x01\n\x07\x43hannel\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x19\n\x15MODEL_UPDATE_REQUESTS\x10\x01\x12\x11\n\rMODEL_UPDATES\x10\x02\x12\x1d\n\x19MODEL_VALIDATION_REQUESTS\x10\x03\x12\x15\n\x11MODEL_VALIDATIONS\x10\x04\x12\n\n\x06STATUS\x10\x05*F\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse0\x01\x32\xe3\x01\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x33\n\x04Stop\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x37\n\tConfigure\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse\x12\x34\n\x06Report\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.grpc.GetGlobalModelRequest\x1a\x1c.grpc.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x0c.grpc.Status0\x01\x12*\n\nSendStatus\x12\x0c.grpc.Status\x1a\x0e.grpc.Response\x12?\n\x11ListActiveClients\x12\x18.grpc.ListClientsRequest\x1a\x10.grpc.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.grpc.ConnectionRequest\x1a\x18.grpc.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.grpc.Heartbeat\x1a\x0e.grpc.Response\x12\x37\n\x0eReassignClient\x12\x15.grpc.ReassignRequest\x1a\x0e.grpc.Response\x12\x39\n\x0fReconnectClient\x12\x16.grpc.ReconnectRequest\x1a\x0e.grpc.Response2\xda\x04\n\x08\x43ombiner\x12T\n\x18ModelUpdateRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x18.grpc.ModelUpdateRequest0\x01\x12\x46\n\x11ModelUpdateStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x11.grpc.ModelUpdate0\x01\x12\\\n\x1cModelValidationRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x1c.grpc.ModelValidationRequest0\x01\x12N\n\x15ModelValidationStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x15.grpc.ModelValidation0\x01\x12\x42\n\x16SendModelUpdateRequest\x12\x18.grpc.ModelUpdateRequest\x1a\x0e.grpc.Response\x12\x34\n\x0fSendModelUpdate\x12\x11.grpc.ModelUpdate\x1a\x0e.grpc.Response\x12J\n\x1aSendModelValidationRequest\x12\x1c.grpc.ModelValidationRequest\x1a\x0e.grpc.Response\x12<\n\x13SendModelValidation\x12\x15.grpc.ModelValidation\x1a\x0e.grpc.Responseb\x06proto3') # noqa: E501 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x66\x65\x64n/common/net/grpc/fedn.proto\x12\x04grpc\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\x8c\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.grpc.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.grpc.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xab\x01\n\x12ModelUpdateRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xaf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xc5\x01\n\x16ModelValidationRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x14\n\x0cis_inference\x18\x08 \x01(\x08\"\xa8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.grpc.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.grpc.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"R\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\r.grpc.Channel\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.grpc.Client\"0\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.grpc.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.grpc.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"R\n\x0eReportResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.grpc.ConnectionStatus*\x84\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\r\n\tINFERENCE\x10\x05*\x86\x01\n\x07\x43hannel\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x19\n\x15MODEL_UPDATE_REQUESTS\x10\x01\x12\x11\n\rMODEL_UPDATES\x10\x02\x12\x1d\n\x19MODEL_VALIDATION_REQUESTS\x10\x03\x12\x15\n\x11MODEL_VALIDATIONS\x10\x04\x12\n\n\x06STATUS\x10\x05*F\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse0\x01\x32\xa9\x02\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x33\n\x04Stop\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x37\n\tConfigure\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x34\n\x06Report\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.grpc.GetGlobalModelRequest\x1a\x1c.grpc.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x0c.grpc.Status0\x01\x12*\n\nSendStatus\x12\x0c.grpc.Status\x1a\x0e.grpc.Response\x12?\n\x11ListActiveClients\x12\x18.grpc.ListClientsRequest\x1a\x10.grpc.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.grpc.ConnectionRequest\x1a\x18.grpc.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.grpc.Heartbeat\x1a\x0e.grpc.Response\x12\x37\n\x0eReassignClient\x12\x15.grpc.ReassignRequest\x1a\x0e.grpc.Response\x12\x39\n\x0fReconnectClient\x12\x16.grpc.ReconnectRequest\x1a\x0e.grpc.Response2\xda\x04\n\x08\x43ombiner\x12T\n\x18ModelUpdateRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x18.grpc.ModelUpdateRequest0\x01\x12\x46\n\x11ModelUpdateStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x11.grpc.ModelUpdate0\x01\x12\\\n\x1cModelValidationRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x1c.grpc.ModelValidationRequest0\x01\x12N\n\x15ModelValidationStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x15.grpc.ModelValidation0\x01\x12\x42\n\x16SendModelUpdateRequest\x12\x18.grpc.ModelUpdateRequest\x1a\x0e.grpc.Response\x12\x34\n\x0fSendModelUpdate\x12\x11.grpc.ModelUpdate\x1a\x0e.grpc.Response\x12J\n\x1aSendModelValidationRequest\x12\x1c.grpc.ModelValidationRequest\x1a\x0e.grpc.Response\x12<\n\x13SendModelValidation\x12\x15.grpc.ModelValidation\x1a\x0e.grpc.Responseb\x06proto3') _STATUSTYPE = DESCRIPTOR.enum_types_by_name['StatusType'] StatusType = enum_type_wrapper.EnumTypeWrapper(_STATUSTYPE) @@ -84,164 +85,164 @@ _CONNECTIONRESPONSE = DESCRIPTOR.message_types_by_name['ConnectionResponse'] _STATUS_LOGLEVEL = _STATUS.enum_types_by_name['LogLevel'] Response = _reflection.GeneratedProtocolMessageType('Response', (_message.Message,), { - 'DESCRIPTOR': _RESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Response) -}) + 'DESCRIPTOR' : _RESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Response) + }) _sym_db.RegisterMessage(Response) Status = _reflection.GeneratedProtocolMessageType('Status', (_message.Message,), { - 'DESCRIPTOR': _STATUS, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Status) -}) + 'DESCRIPTOR' : _STATUS, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Status) + }) _sym_db.RegisterMessage(Status) ModelUpdateRequest = _reflection.GeneratedProtocolMessageType('ModelUpdateRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELUPDATEREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelUpdateRequest) -}) + 'DESCRIPTOR' : _MODELUPDATEREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelUpdateRequest) + }) _sym_db.RegisterMessage(ModelUpdateRequest) ModelUpdate = _reflection.GeneratedProtocolMessageType('ModelUpdate', (_message.Message,), { - 'DESCRIPTOR': _MODELUPDATE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelUpdate) -}) + 'DESCRIPTOR' : _MODELUPDATE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelUpdate) + }) _sym_db.RegisterMessage(ModelUpdate) ModelValidationRequest = _reflection.GeneratedProtocolMessageType('ModelValidationRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELVALIDATIONREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelValidationRequest) -}) + 'DESCRIPTOR' : _MODELVALIDATIONREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelValidationRequest) + }) _sym_db.RegisterMessage(ModelValidationRequest) ModelValidation = _reflection.GeneratedProtocolMessageType('ModelValidation', (_message.Message,), { - 'DESCRIPTOR': _MODELVALIDATION, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelValidation) -}) + 'DESCRIPTOR' : _MODELVALIDATION, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelValidation) + }) _sym_db.RegisterMessage(ModelValidation) ModelRequest = _reflection.GeneratedProtocolMessageType('ModelRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelRequest) -}) + 'DESCRIPTOR' : _MODELREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelRequest) + }) _sym_db.RegisterMessage(ModelRequest) ModelResponse = _reflection.GeneratedProtocolMessageType('ModelResponse', (_message.Message,), { - 'DESCRIPTOR': _MODELRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelResponse) -}) + 'DESCRIPTOR' : _MODELRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelResponse) + }) _sym_db.RegisterMessage(ModelResponse) GetGlobalModelRequest = _reflection.GeneratedProtocolMessageType('GetGlobalModelRequest', (_message.Message,), { - 'DESCRIPTOR': _GETGLOBALMODELREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelRequest) -}) + 'DESCRIPTOR' : _GETGLOBALMODELREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelRequest) + }) _sym_db.RegisterMessage(GetGlobalModelRequest) GetGlobalModelResponse = _reflection.GeneratedProtocolMessageType('GetGlobalModelResponse', (_message.Message,), { - 'DESCRIPTOR': _GETGLOBALMODELRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelResponse) -}) + 'DESCRIPTOR' : _GETGLOBALMODELRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelResponse) + }) _sym_db.RegisterMessage(GetGlobalModelResponse) Heartbeat = _reflection.GeneratedProtocolMessageType('Heartbeat', (_message.Message,), { - 'DESCRIPTOR': _HEARTBEAT, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Heartbeat) -}) + 'DESCRIPTOR' : _HEARTBEAT, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Heartbeat) + }) _sym_db.RegisterMessage(Heartbeat) ClientAvailableMessage = _reflection.GeneratedProtocolMessageType('ClientAvailableMessage', (_message.Message,), { - 'DESCRIPTOR': _CLIENTAVAILABLEMESSAGE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ClientAvailableMessage) -}) + 'DESCRIPTOR' : _CLIENTAVAILABLEMESSAGE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ClientAvailableMessage) + }) _sym_db.RegisterMessage(ClientAvailableMessage) ListClientsRequest = _reflection.GeneratedProtocolMessageType('ListClientsRequest', (_message.Message,), { - 'DESCRIPTOR': _LISTCLIENTSREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ListClientsRequest) -}) + 'DESCRIPTOR' : _LISTCLIENTSREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ListClientsRequest) + }) _sym_db.RegisterMessage(ListClientsRequest) ClientList = _reflection.GeneratedProtocolMessageType('ClientList', (_message.Message,), { - 'DESCRIPTOR': _CLIENTLIST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ClientList) -}) + 'DESCRIPTOR' : _CLIENTLIST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ClientList) + }) _sym_db.RegisterMessage(ClientList) Client = _reflection.GeneratedProtocolMessageType('Client', (_message.Message,), { - 'DESCRIPTOR': _CLIENT, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Client) -}) + 'DESCRIPTOR' : _CLIENT, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Client) + }) _sym_db.RegisterMessage(Client) ReassignRequest = _reflection.GeneratedProtocolMessageType('ReassignRequest', (_message.Message,), { - 'DESCRIPTOR': _REASSIGNREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReassignRequest) -}) + 'DESCRIPTOR' : _REASSIGNREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReassignRequest) + }) _sym_db.RegisterMessage(ReassignRequest) ReconnectRequest = _reflection.GeneratedProtocolMessageType('ReconnectRequest', (_message.Message,), { - 'DESCRIPTOR': _RECONNECTREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReconnectRequest) -}) + 'DESCRIPTOR' : _RECONNECTREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReconnectRequest) + }) _sym_db.RegisterMessage(ReconnectRequest) Parameter = _reflection.GeneratedProtocolMessageType('Parameter', (_message.Message,), { - 'DESCRIPTOR': _PARAMETER, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Parameter) -}) + 'DESCRIPTOR' : _PARAMETER, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Parameter) + }) _sym_db.RegisterMessage(Parameter) ControlRequest = _reflection.GeneratedProtocolMessageType('ControlRequest', (_message.Message,), { - 'DESCRIPTOR': _CONTROLREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ControlRequest) -}) + 'DESCRIPTOR' : _CONTROLREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ControlRequest) + }) _sym_db.RegisterMessage(ControlRequest) ControlResponse = _reflection.GeneratedProtocolMessageType('ControlResponse', (_message.Message,), { - 'DESCRIPTOR': _CONTROLRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ControlResponse) -}) + 'DESCRIPTOR' : _CONTROLRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ControlResponse) + }) _sym_db.RegisterMessage(ControlResponse) ReportResponse = _reflection.GeneratedProtocolMessageType('ReportResponse', (_message.Message,), { - 'DESCRIPTOR': _REPORTRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReportResponse) -}) + 'DESCRIPTOR' : _REPORTRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReportResponse) + }) _sym_db.RegisterMessage(ReportResponse) ConnectionRequest = _reflection.GeneratedProtocolMessageType('ConnectionRequest', (_message.Message,), { - 'DESCRIPTOR': _CONNECTIONREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ConnectionRequest) -}) + 'DESCRIPTOR' : _CONNECTIONREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ConnectionRequest) + }) _sym_db.RegisterMessage(ConnectionRequest) ConnectionResponse = _reflection.GeneratedProtocolMessageType('ConnectionResponse', (_message.Message,), { - 'DESCRIPTOR': _CONNECTIONRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ConnectionResponse) -}) + 'DESCRIPTOR' : _CONNECTIONRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ConnectionResponse) + }) _sym_db.RegisterMessage(ConnectionResponse) _MODELSERVICE = DESCRIPTOR.services_by_name['ModelService'] @@ -249,77 +250,77 @@ _REDUCER = DESCRIPTOR.services_by_name['Reducer'] _CONNECTOR = DESCRIPTOR.services_by_name['Connector'] _COMBINER = DESCRIPTOR.services_by_name['Combiner'] -if _descriptor._USE_C_DESCRIPTORS is False: - - DESCRIPTOR._options = None - _STATUSTYPE._serialized_start = 2412 - _STATUSTYPE._serialized_end = 2544 - _CHANNEL._serialized_start = 2547 - _CHANNEL._serialized_end = 2681 - _MODELSTATUS._serialized_start = 2683 - _MODELSTATUS._serialized_end = 2753 - _ROLE._serialized_start = 2755 - _ROLE._serialized_end = 2811 - _COMMAND._serialized_start = 2813 - _COMMAND._serialized_end = 2887 - _CONNECTIONSTATUS._serialized_start = 2889 - _CONNECTIONSTATUS._serialized_end = 2962 - _RESPONSE._serialized_start = 41 - _RESPONSE._serialized_end = 99 - _STATUS._serialized_start = 102 - _STATUS._serialized_end = 370 - _STATUS_LOGLEVEL._serialized_start = 304 - _STATUS_LOGLEVEL._serialized_end = 370 - _MODELUPDATEREQUEST._serialized_start = 373 - _MODELUPDATEREQUEST._serialized_end = 544 - _MODELUPDATE._serialized_start = 547 - _MODELUPDATE._serialized_end = 722 - _MODELVALIDATIONREQUEST._serialized_start = 725 - _MODELVALIDATIONREQUEST._serialized_end = 922 - _MODELVALIDATION._serialized_start = 925 - _MODELVALIDATION._serialized_end = 1093 - _MODELREQUEST._serialized_start = 1096 - _MODELREQUEST._serialized_end = 1233 - _MODELRESPONSE._serialized_start = 1235 - _MODELRESPONSE._serialized_end = 1328 - _GETGLOBALMODELREQUEST._serialized_start = 1330 - _GETGLOBALMODELREQUEST._serialized_end = 1415 - _GETGLOBALMODELRESPONSE._serialized_start = 1417 - _GETGLOBALMODELRESPONSE._serialized_end = 1521 - _HEARTBEAT._serialized_start = 1523 - _HEARTBEAT._serialized_end = 1564 - _CLIENTAVAILABLEMESSAGE._serialized_start = 1566 - _CLIENTAVAILABLEMESSAGE._serialized_end = 1653 - _LISTCLIENTSREQUEST._serialized_start = 1655 - _LISTCLIENTSREQUEST._serialized_end = 1737 - _CLIENTLIST._serialized_start = 1739 - _CLIENTLIST._serialized_end = 1781 - _CLIENT._serialized_start = 1783 - _CLIENT._serialized_end = 1831 - _REASSIGNREQUEST._serialized_start = 1833 - _REASSIGNREQUEST._serialized_end = 1942 - _RECONNECTREQUEST._serialized_start = 1944 - _RECONNECTREQUEST._serialized_end = 2043 - _PARAMETER._serialized_start = 2045 - _PARAMETER._serialized_end = 2084 - _CONTROLREQUEST._serialized_start = 2086 - _CONTROLREQUEST._serialized_end = 2170 - _CONTROLRESPONSE._serialized_start = 2172 - _CONTROLRESPONSE._serialized_end = 2242 - _REPORTRESPONSE._serialized_start = 2244 - _REPORTRESPONSE._serialized_end = 2326 - _CONNECTIONREQUEST._serialized_start = 2328 - _CONNECTIONREQUEST._serialized_end = 2347 - _CONNECTIONRESPONSE._serialized_start = 2349 - _CONNECTIONRESPONSE._serialized_end = 2409 - _MODELSERVICE._serialized_start = 2964 - _MODELSERVICE._serialized_end = 3086 - _CONTROL._serialized_start = 3089 - _CONTROL._serialized_end = 3316 - _REDUCER._serialized_start = 3318 - _REDUCER._serialized_end = 3404 - _CONNECTOR._serialized_start = 3407 - _CONNECTOR._serialized_end = 3834 - _COMBINER._serialized_start = 3837 - _COMBINER._serialized_end = 4439 +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _STATUSTYPE._serialized_start=2412 + _STATUSTYPE._serialized_end=2544 + _CHANNEL._serialized_start=2547 + _CHANNEL._serialized_end=2681 + _MODELSTATUS._serialized_start=2683 + _MODELSTATUS._serialized_end=2753 + _ROLE._serialized_start=2755 + _ROLE._serialized_end=2811 + _COMMAND._serialized_start=2813 + _COMMAND._serialized_end=2887 + _CONNECTIONSTATUS._serialized_start=2889 + _CONNECTIONSTATUS._serialized_end=2962 + _RESPONSE._serialized_start=41 + _RESPONSE._serialized_end=99 + _STATUS._serialized_start=102 + _STATUS._serialized_end=370 + _STATUS_LOGLEVEL._serialized_start=304 + _STATUS_LOGLEVEL._serialized_end=370 + _MODELUPDATEREQUEST._serialized_start=373 + _MODELUPDATEREQUEST._serialized_end=544 + _MODELUPDATE._serialized_start=547 + _MODELUPDATE._serialized_end=722 + _MODELVALIDATIONREQUEST._serialized_start=725 + _MODELVALIDATIONREQUEST._serialized_end=922 + _MODELVALIDATION._serialized_start=925 + _MODELVALIDATION._serialized_end=1093 + _MODELREQUEST._serialized_start=1096 + _MODELREQUEST._serialized_end=1233 + _MODELRESPONSE._serialized_start=1235 + _MODELRESPONSE._serialized_end=1328 + _GETGLOBALMODELREQUEST._serialized_start=1330 + _GETGLOBALMODELREQUEST._serialized_end=1415 + _GETGLOBALMODELRESPONSE._serialized_start=1417 + _GETGLOBALMODELRESPONSE._serialized_end=1521 + _HEARTBEAT._serialized_start=1523 + _HEARTBEAT._serialized_end=1564 + _CLIENTAVAILABLEMESSAGE._serialized_start=1566 + _CLIENTAVAILABLEMESSAGE._serialized_end=1653 + _LISTCLIENTSREQUEST._serialized_start=1655 + _LISTCLIENTSREQUEST._serialized_end=1737 + _CLIENTLIST._serialized_start=1739 + _CLIENTLIST._serialized_end=1781 + _CLIENT._serialized_start=1783 + _CLIENT._serialized_end=1831 + _REASSIGNREQUEST._serialized_start=1833 + _REASSIGNREQUEST._serialized_end=1942 + _RECONNECTREQUEST._serialized_start=1944 + _RECONNECTREQUEST._serialized_end=2043 + _PARAMETER._serialized_start=2045 + _PARAMETER._serialized_end=2084 + _CONTROLREQUEST._serialized_start=2086 + _CONTROLREQUEST._serialized_end=2170 + _CONTROLRESPONSE._serialized_start=2172 + _CONTROLRESPONSE._serialized_end=2242 + _REPORTRESPONSE._serialized_start=2244 + _REPORTRESPONSE._serialized_end=2326 + _CONNECTIONREQUEST._serialized_start=2328 + _CONNECTIONREQUEST._serialized_end=2347 + _CONNECTIONRESPONSE._serialized_start=2349 + _CONNECTIONRESPONSE._serialized_end=2409 + _MODELSERVICE._serialized_start=2964 + _MODELSERVICE._serialized_end=3086 + _CONTROL._serialized_start=3089 + _CONTROL._serialized_end=3386 + _REDUCER._serialized_start=3388 + _REDUCER._serialized_end=3474 + _CONNECTOR._serialized_start=3477 + _CONNECTOR._serialized_end=3904 + _COMBINER._serialized_start=3907 + _COMBINER._serialized_end=4509 # @@protoc_insertion_point(module_scope) diff --git a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py index 7cf017b34..9590e2b5c 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py @@ -2,8 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from fedn.common.net.grpc import \ - fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 +from fedn.common.net.grpc import fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 class ModelServiceStub(object): @@ -16,15 +15,15 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Upload = channel.stream_unary( - '/grpc.ModelService/Upload', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - ) + '/grpc.ModelService/Upload', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + ) self.Download = channel.unary_stream( - '/grpc.ModelService/Download', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - ) + '/grpc.ModelService/Download', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + ) class ModelServiceServicer(object): @@ -45,60 +44,59 @@ def Download(self, request, context): def add_ModelServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'Upload': grpc.stream_unary_rpc_method_handler( - servicer.Upload, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, - ), - 'Download': grpc.unary_stream_rpc_method_handler( - servicer.Download, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, - ), + 'Upload': grpc.stream_unary_rpc_method_handler( + servicer.Upload, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + ), + 'Download': grpc.unary_stream_rpc_method_handler( + servicer.Download, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.ModelService', rpc_method_handlers) + 'grpc.ModelService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class ModelService(object): """Missing associated documentation comment in .proto file.""" @staticmethod def Upload(request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.stream_unary(request_iterator, target, '/grpc.ModelService/Upload', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Download(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.ModelService/Download', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ControlStub(object): @@ -111,25 +109,30 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Start = channel.unary_unary( - '/grpc.Control/Start', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - ) + '/grpc.Control/Start', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Stop = channel.unary_unary( - '/grpc.Control/Stop', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - ) + '/grpc.Control/Stop', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Configure = channel.unary_unary( - '/grpc.Control/Configure', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - ) + '/grpc.Control/Configure', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + ) + self.FlushAggregationQueue = channel.unary_unary( + '/grpc.Control/FlushAggregationQueue', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Report = channel.unary_unary( - '/grpc.Control/Report', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - ) + '/grpc.Control/Report', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + ) class ControlServicer(object): @@ -153,6 +156,12 @@ def Configure(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def FlushAggregationQueue(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Report(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -162,104 +171,125 @@ def Report(self, request, context): def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { - 'Start': grpc.unary_unary_rpc_method_handler( - servicer.Start, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, - ), - 'Stop': grpc.unary_unary_rpc_method_handler( - servicer.Stop, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, - ), - 'Configure': grpc.unary_unary_rpc_method_handler( - servicer.Configure, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, - ), - 'Report': grpc.unary_unary_rpc_method_handler( - servicer.Report, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, - ), + 'Start': grpc.unary_unary_rpc_method_handler( + servicer.Start, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Stop': grpc.unary_unary_rpc_method_handler( + servicer.Stop, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Configure': grpc.unary_unary_rpc_method_handler( + servicer.Configure, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, + ), + 'FlushAggregationQueue': grpc.unary_unary_rpc_method_handler( + servicer.FlushAggregationQueue, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Report': grpc.unary_unary_rpc_method_handler( + servicer.Report, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Control', rpc_method_handlers) + 'grpc.Control', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Control(object): """Missing associated documentation comment in .proto file.""" @staticmethod def Start(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Start', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Stop(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Stop', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Configure(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Configure', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def FlushAggregationQueue(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/grpc.Control/FlushAggregationQueue', + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Report(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Report', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ReducerStub(object): @@ -272,10 +302,10 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.GetGlobalModel = channel.unary_unary( - '/grpc.Reducer/GetGlobalModel', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, - ) + '/grpc.Reducer/GetGlobalModel', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, + ) class ReducerServicer(object): @@ -290,38 +320,37 @@ def GetGlobalModel(self, request, context): def add_ReducerServicer_to_server(servicer, server): rpc_method_handlers = { - 'GetGlobalModel': grpc.unary_unary_rpc_method_handler( - servicer.GetGlobalModel, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString, - ), + 'GetGlobalModel': grpc.unary_unary_rpc_method_handler( + servicer.GetGlobalModel, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Reducer', rpc_method_handlers) + 'grpc.Reducer', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Reducer(object): """Missing associated documentation comment in .proto file.""" @staticmethod def GetGlobalModel(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Reducer/GetGlobalModel', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ConnectorStub(object): @@ -334,40 +363,40 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.AllianceStatusStream = channel.unary_stream( - '/grpc.Connector/AllianceStatusStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - ) + '/grpc.Connector/AllianceStatusStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + ) self.SendStatus = channel.unary_unary( - '/grpc.Connector/SendStatus', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/SendStatus', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ListActiveClients = channel.unary_unary( - '/grpc.Connector/ListActiveClients', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, - ) + '/grpc.Connector/ListActiveClients', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, + ) self.AcceptingClients = channel.unary_unary( - '/grpc.Connector/AcceptingClients', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, - ) + '/grpc.Connector/AcceptingClients', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, + ) self.SendHeartbeat = channel.unary_unary( - '/grpc.Connector/SendHeartbeat', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/SendHeartbeat', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ReassignClient = channel.unary_unary( - '/grpc.Connector/ReassignClient', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/ReassignClient', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ReconnectClient = channel.unary_unary( - '/grpc.Connector/ReconnectClient', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/ReconnectClient', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) class ConnectorServicer(object): @@ -423,170 +452,169 @@ def ReconnectClient(self, request, context): def add_ConnectorServicer_to_server(servicer, server): rpc_method_handlers = { - 'AllianceStatusStream': grpc.unary_stream_rpc_method_handler( - servicer.AllianceStatusStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - ), - 'SendStatus': grpc.unary_unary_rpc_method_handler( - servicer.SendStatus, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ListActiveClients': grpc.unary_unary_rpc_method_handler( - servicer.ListActiveClients, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString, - ), - 'AcceptingClients': grpc.unary_unary_rpc_method_handler( - servicer.AcceptingClients, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString, - ), - 'SendHeartbeat': grpc.unary_unary_rpc_method_handler( - servicer.SendHeartbeat, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ReassignClient': grpc.unary_unary_rpc_method_handler( - servicer.ReassignClient, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ReconnectClient': grpc.unary_unary_rpc_method_handler( - servicer.ReconnectClient, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), + 'AllianceStatusStream': grpc.unary_stream_rpc_method_handler( + servicer.AllianceStatusStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + ), + 'SendStatus': grpc.unary_unary_rpc_method_handler( + servicer.SendStatus, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ListActiveClients': grpc.unary_unary_rpc_method_handler( + servicer.ListActiveClients, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString, + ), + 'AcceptingClients': grpc.unary_unary_rpc_method_handler( + servicer.AcceptingClients, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString, + ), + 'SendHeartbeat': grpc.unary_unary_rpc_method_handler( + servicer.SendHeartbeat, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ReassignClient': grpc.unary_unary_rpc_method_handler( + servicer.ReassignClient, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ReconnectClient': grpc.unary_unary_rpc_method_handler( + servicer.ReconnectClient, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Connector', rpc_method_handlers) + 'grpc.Connector', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Connector(object): """Missing associated documentation comment in .proto file.""" @staticmethod def AllianceStatusStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Connector/AllianceStatusStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendStatus(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/SendStatus', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ListActiveClients(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ListActiveClients', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def AcceptingClients(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/AcceptingClients', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendHeartbeat(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/SendHeartbeat', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ReassignClient(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ReassignClient', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ReconnectClient(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ReconnectClient', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class CombinerStub(object): @@ -599,45 +627,45 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.ModelUpdateRequestStream = channel.unary_stream( - '/grpc.Combiner/ModelUpdateRequestStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - ) + '/grpc.Combiner/ModelUpdateRequestStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + ) self.ModelUpdateStream = channel.unary_stream( - '/grpc.Combiner/ModelUpdateStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - ) + '/grpc.Combiner/ModelUpdateStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + ) self.ModelValidationRequestStream = channel.unary_stream( - '/grpc.Combiner/ModelValidationRequestStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - ) + '/grpc.Combiner/ModelValidationRequestStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + ) self.ModelValidationStream = channel.unary_stream( - '/grpc.Combiner/ModelValidationStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, - ) + '/grpc.Combiner/ModelValidationStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + ) self.SendModelUpdateRequest = channel.unary_unary( - '/grpc.Combiner/SendModelUpdateRequest', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelUpdateRequest', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelUpdate = channel.unary_unary( - '/grpc.Combiner/SendModelUpdate', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelUpdate', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelValidationRequest = channel.unary_unary( - '/grpc.Combiner/SendModelValidationRequest', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelValidationRequest', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelValidation = channel.unary_unary( - '/grpc.Combiner/SendModelValidation', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelValidation', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) class CombinerServicer(object): @@ -695,189 +723,188 @@ def SendModelValidation(self, request, context): def add_CombinerServicer_to_server(servicer, server): rpc_method_handlers = { - 'ModelUpdateRequestStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelUpdateRequestStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - ), - 'ModelUpdateStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelUpdateStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - ), - 'ModelValidationRequestStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelValidationRequestStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - ), - 'ModelValidationStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelValidationStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - ), - 'SendModelUpdateRequest': grpc.unary_unary_rpc_method_handler( - servicer.SendModelUpdateRequest, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelUpdate': grpc.unary_unary_rpc_method_handler( - servicer.SendModelUpdate, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelValidationRequest': grpc.unary_unary_rpc_method_handler( - servicer.SendModelValidationRequest, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelValidation': grpc.unary_unary_rpc_method_handler( - servicer.SendModelValidation, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), + 'ModelUpdateRequestStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelUpdateRequestStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + ), + 'ModelUpdateStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelUpdateStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + ), + 'ModelValidationRequestStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelValidationRequestStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + ), + 'ModelValidationStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelValidationStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + ), + 'SendModelUpdateRequest': grpc.unary_unary_rpc_method_handler( + servicer.SendModelUpdateRequest, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelUpdate': grpc.unary_unary_rpc_method_handler( + servicer.SendModelUpdate, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelValidationRequest': grpc.unary_unary_rpc_method_handler( + servicer.SendModelValidationRequest, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelValidation': grpc.unary_unary_rpc_method_handler( + servicer.SendModelValidation, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Combiner', rpc_method_handlers) + 'grpc.Combiner', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Combiner(object): """Missing associated documentation comment in .proto file.""" @staticmethod def ModelUpdateRequestStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelUpdateRequestStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelUpdateStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelUpdateStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelValidationRequestStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelValidationRequestStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelValidationStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelValidationStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelUpdateRequest(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelUpdateRequest', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelUpdate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelUpdate', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelValidationRequest(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelValidationRequest', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelValidation(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelValidation', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 92af569ea..0a3e28cdc 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime from google.protobuf.json_format import MessageToDict @@ -18,6 +19,7 @@ def __init__(self, mongo_config, network_id): self.rounds = self.mdb['control.rounds'] self.sessions = self.mdb['control.sessions'] self.validations = self.mdb['control.validations'] + self.clients = self.mdb['network.clients'] except Exception as e: print("FAILED TO CONNECT TO MONGO, {}".format(e), flush=True) self.status = None @@ -82,3 +84,17 @@ def set_round_data(self, round_data): """ self.rounds.update_one({'round_id': str(round_data['round_id'])}, { '$push': {'reducer': round_data}}, True) + + def update_client_status(self, client_name, status): + """ Update client status in statestore. + :param client_name: The client name + :type client_name: str + :param status: The client status + :type status: str + :return: None + """ + datetime_now = datetime.now() + filter_query = {"name": client_name} + + update_query = {"$set": {"last_seen": datetime_now, "status": status}} + self.clients.update_one(filter_query, update_query) diff --git a/fedn/fedn/network/__init__.py b/fedn/fedn/network/__init__.py index 52ce8c9c3..ec5dfd71a 100644 --- a/fedn/fedn/network/__init__.py +++ b/fedn/fedn/network/__init__.py @@ -1,3 +1,3 @@ -# -# Scaleout Systems AB -# __author__ = 'Morgan Ekmefjord morgan@scaleout.se' +""" The statestore package is responsible for storing various states of the federated network. Such as announced combiners and assigned clients. It also stores metadata about +models, rounds, sessions, compute packages and model validations. """ +# flake8: noqa diff --git a/fedn/fedn/network/api/__init__.py b/fedn/fedn/network/api/__init__.py new file mode 100644 index 000000000..9cb788d09 --- /dev/null +++ b/fedn/fedn/network/api/__init__.py @@ -0,0 +1,3 @@ +""" API module for the FEDn network. Includes a REST-API server to interact with the controller +and statestore.""" +# flake8: noqa diff --git a/fedn/fedn/network/api/client.py b/fedn/fedn/network/api/client.py new file mode 100644 index 000000000..0e0a48a52 --- /dev/null +++ b/fedn/fedn/network/api/client.py @@ -0,0 +1,280 @@ +import uuid + +import requests + +__all__ = ['APIClient'] + + +class APIClient: + """ An API client for interacting with the statestore and controller. + + :param host: The host of the api server. + :type host: str + :param port: The port of the api server. + :type port: int + :param secure: Whether to use https. + :type secure: bool + :param verify: Whether to verify the server certificate. + :type verify: bool + """ + + def __init__(self, host, port, secure=False, verify=False): + self.host = host + self.port = port + self.secure = secure + self.verify = verify + + def _get_url(self, endpoint): + if self.secure: + protocol = 'https' + else: + protocol = 'http' + return f'{protocol}://{self.host}:{self.port}/{endpoint}' + + def get_model_trail(self): + """ Get the model trail. + + :return: The model trail as dict including commit timestamp. + :rtype: dict + """ + response = requests.get(self._get_url('get_model_trail'), verify=self.verify) + return response.json() + + def list_clients(self): + """ Get all clients from the statestore. + + return: All clients. + rtype: dict + """ + response = requests.get(self._get_url('list_clients')) + return response.json() + + def get_active_clients(self, combiner_id): + """ Get all active clients from the statestore. + + :param combiner_id: The combiner id to get active clients for. + :type combiner_id: str + :return: All active clients. + :rtype: dict + """ + response = requests.get(self._get_url('get_active_clients'), params={'combiner': combiner_id}, verify=self.verify) + return response.json() + + def get_client_config(self, checksum=True): + """ Get the controller configuration. Optionally include the checksum. + The config is used for clients to connect to the controller and ask for combiner assignment. + + :param checksum: Whether to include the checksum of the package. + :type checksum: bool + :return: The client configuration. + :rtype: dict + """ + response = requests.get(self._get_url('get_client_config'), params={'checksum': checksum}, verify=self.verify) + return response.json() + + def list_combiners(self): + """ Get all combiners in the network. + + :return: All combiners with info. + :rtype: dict + """ + response = requests.get(self._get_url('list_combiners')) + return response.json() + + def get_combiner(self, combiner_id): + """ Get a combiner from the statestore. + + :param combiner_id: The combiner id to get. + :type combiner_id: str + :return: The combiner info. + :rtype: dict + """ + response = requests.get(self._get_url(f'get_combiner?combiner={combiner_id}'), verify=self.verify) + return response.json() + + def list_rounds(self): + """ Get all rounds from the statestore. + + :return: All rounds with config and metrics. + :rtype: dict + """ + response = requests.get(self._get_url('list_rounds')) + return response.json() + + def get_round(self, round_id): + """ Get a round from the statestore. + + :param round_id: The round id to get. + :type round_id: str + :return: The round config and metrics. + :rtype: dict + """ + response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify) + return response.json() + + def start_session(self, session_id=None, round_timeout=180, rounds=5, round_buffer_size=-1, delete_models=True, + validate=True, helper='kerashelper', min_clients=1, requested_clients=8): + """ Start a new session. + + :param session_id: The session id to start. + :type session_id: str + :param round_timeout: The round timeout to use in seconds. + :type round_timeout: int + :param rounds: The number of rounds to perform. + :type rounds: int + :param round_buffer_size: The round buffer size to use. + :type round_buffer_size: int + :param delete_models: Whether to delete models after each round at combiner (save storage). + :type delete_models: bool + :param validate: Whether to validate the model after each round. + :type validate: bool + :param helper: The helper type to use. + :type helper: str + :param min_clients: The minimum number of clients required. + :type min_clients: int + :param requested_clients: The requested number of clients. + :type requested_clients: int + :return: A dict with success or failure message and session config. + :rtype: dict + """ + # If session id is None, generate a random session id. + if session_id is None: + session_id = str(uuid.uuid4()) + response = requests.post(self._get_url('start_session'), json={ + 'session_id': session_id, + 'round_timeout': round_timeout, + 'rounds': rounds, + 'round_buffer_size': round_buffer_size, + 'delete_models': delete_models, + 'validate': validate, + 'helper': helper, + 'min_clients': min_clients, + 'requested_clients': requested_clients + }, verify=self.verify + ) + return response.json() + + def list_sessions(self): + """ Get all sessions from the statestore. + + :return: All sessions in dict. + :rtype: dict + """ + response = requests.get(self._get_url('list_sessions'), verify=self.verify) + return response.json() + + def get_session(self, session_id): + """ Get a session from the statestore. + + :param session_id: The session id to get. + :type session_id: str + :return: The session as a json object. + :rtype: dict + """ + response = requests.get(self._get_url(f'get_session?session_id={session_id}'), self.verify) + return response.json() + + def set_package(self, path, helper): + """ Set the compute package in the statestore. + + :param path: The file path of the compute package to set. + :type path: str + :param helper: The helper type to use. + :type helper: str + :return: A dict with success or failure message. + :rtype: dict + """ + with open(path, 'rb') as file: + response = requests.post(self._get_url('set_package'), files={'file': file}, data={'helper': helper}, verify=self.verify) + return response.json() + + def get_package(self): + """ Get the compute package from the statestore. + + :return: The compute package with info. + :rtype: dict + """ + response = requests.get(self._get_url('get_package'), verify=self.verify) + return response.json() + + def download_package(self, path): + """ Download the compute package. + + :param path: The path to download the compute package to. + :type path: str + :return: Message with success or failure. + :rtype: dict + """ + response = requests.get(self._get_url('download_package'), verify=self.verify) + if response.status_code == 200: + with open(path, 'wb') as file: + file.write(response.content) + return {'success': True, 'message': 'Package downloaded successfully.'} + else: + return {'success': False, 'message': 'Failed to download package.'} + + def get_package_checksum(self): + """ Get the checksum of the compute package. + + :return: The checksum. + :rtype: dict + """ + response = requests.get(self._get_url('get_package_checksum'), verify=self.verify) + return response.json() + + def get_latest_model(self): + """ Get the latest model from the statestore. + + :return: The latest model id. + :rtype: dict + """ + response = requests.get(self._get_url('get_latest_model'), verify=self.verify) + return response.json() + + def get_initial_model(self): + """ Get the initial model from the statestore. + + :return: The initial model id. + :rtype: dict + """ + response = requests.get(self._get_url('get_initial_model'), verify=self.verify) + return response.json() + + def set_initial_model(self, path): + """ Set the initial model in the statestore and upload to model repository. + + :param path: The file path of the initial model to set. + :type path: str + :return: A dict with success or failure message. + :rtype: dict + """ + with open(path, 'rb') as file: + response = requests.post(self._get_url('set_initial_model'), files={'file': file}, verify=self.verify) + return response.json() + + def get_controller_status(self): + """ Get the status of the controller. + + :return: The status of the controller. + :rtype: dict + """ + response = requests.get(self._get_url('get_controller_status'), verify=self.verify) + return response.json() + + def get_events(self, **kwargs): + """ Get the events from the statestore. Pass kwargs to filter events. + + :return: The events in dict + :rtype: dict + """ + response = requests.get(self._get_url('get_events'), params=kwargs, verify=self.verify) + return response.json() + + def list_validations(self, **kwargs): + """ Get all validations from the statestore. Pass kwargs to filter validations. + + :return: All validations in dict. + :rtype: dict + """ + response = requests.get(self._get_url('list_validations'), params=kwargs, verify=self.verify) + return response.json() diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py new file mode 100644 index 000000000..3a20e3502 --- /dev/null +++ b/fedn/fedn/network/api/interface.py @@ -0,0 +1,892 @@ +import base64 +import copy +import os +import threading +from io import BytesIO + +from flask import jsonify, send_from_directory +from werkzeug.utils import secure_filename + +from fedn.common.config import get_controller_config, get_network_config +from fedn.network.combiner.interfaces import (CombinerInterface, + CombinerUnavailableError) +from fedn.network.dashboard.plots import Plot +from fedn.network.state import ReducerState, ReducerStateToString +from fedn.utils.checksum import sha + +__all__ = ("API",) + + +class API: + """The API class is a wrapper for the statestore. It is used to expose the statestore to the network API.""" + + def __init__(self, statestore, control): + self.statestore = statestore + self.control = control + self.name = "api" + + def _to_dict(self): + """Convert the object to a dict. + + ::return: The object as a dict. + ::rtype: dict + """ + data = {"name": self.name} + return data + + def _get_combiner_report(self, combiner_id): + """Get report response from combiner. + + :param combiner_id: The combiner id to get report response from. + :type combiner_id: str + ::return: The report response from combiner. + ::rtype: dict + """ + # Get CombinerInterface (fedn.network.combiner.inferface.CombinerInterface) for combiner_id + combiner = self.control.network.get_combiner(combiner_id) + report = combiner.report + return report + + def _allowed_file_extension( + self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} + ): + """Check if file extension is allowed. + + :param filename: The filename to check. + :type filename: str + :return: True if file extension is allowed, else False. + :rtype: bool + """ + + return ( + "." in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) + + def get_clients(self, limit=None, skip=None, status=False): + """Get all clients from the statestore. + + :return: All clients as a json response. + :rtype: :class:`flask.Response` + """ + # Will return list of ObjectId + response = self.statestore.list_clients(limit, skip, status) + + arr = [] + + for element in response["result"]: + obj = { + "id": element["name"], + "combiner": element["combiner"], + "combiner_preferred": element["combiner_preferred"], + "ip": element["ip"], + "status": element["status"], + "last_seen": element["last_seen"], + } + + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) + + def get_active_clients(self, combiner_id): + """Get all active clients, i.e that are assigned to a combiner. + A report request to the combiner is neccessary to determine if a client is active or not. + + :param combiner_id: The combiner id to get active clients for. + :type combiner_id: str + :return: All active clients as a json response. + :rtype: :class:`flask.Response` + """ + # Get combiner interface object + combiner = self.control.network.get_combiner(combiner_id) + if combiner is None: + return ( + jsonify( + { + "success": False, + "message": f"Combiner {combiner_id} not found.", + } + ), + 404, + ) + response = combiner.list_active_clients() + return response + + def get_all_combiners(self, limit=None, skip=None): + """Get all combiners from the statestore. + + :return: All combiners as a json response. + :rtype: :class:`flask.Response` + """ + # Will return list of ObjectId + projection = {"name": True, "updated_at": True} + response = self.statestore.get_combiners(limit, skip, projection=projection) + arr = [] + for element in response["result"]: + obj = { + "name": element["name"], + "updated_at": element["updated_at"], + } + + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) + + def get_combiner(self, combiner_id): + """Get a combiner from the statestore. + + :param combiner_id: The combiner id to get. + :type combiner_id: str + :return: The combiner info dict as a json response. + :rtype: :class:`flask.Response` + """ + # Will return ObjectId + object = self.statestore.get_combiner(combiner_id) + payload = {} + id = object["name"] + info = { + "address": object["address"], + "fqdn": object["fqdn"], + "parent_reducer": object["parent"]["name"], + "port": object["port"], + "report": object["report"], + "updated_at": object["updated_at"], + } + payload[id] = info + + return jsonify(payload) + + def get_all_sessions(self, limit=None, skip=None): + """Get all sessions from the statestore. + + :return: All sessions as a json response. + :rtype: :class:`flask.Response` + """ + sessions_object = self.statestore.get_sessions(limit, skip) + if sessions_object is None: + return ( + jsonify({"success": False, "message": "No sessions found."}), + 404, + ) + arr = [] + for element in sessions_object["result"]: + obj = element["session_config"][0] + arr.append(obj) + + result = {"result": arr, "count": sessions_object["count"]} + + return jsonify(result) + + def get_session(self, session_id): + """Get a session from the statestore. + + :param session_id: The session id to get. + :type session_id: str + :return: The session info dict as a json response. + :rtype: :class:`flask.Response` + """ + session_object = self.statestore.get_session(session_id) + if session_object is None: + return ( + jsonify( + { + "success": False, + "message": f"Session {session_id} not found.", + } + ), + 404, + ) + payload = {} + id = session_object["session_id"] + info = session_object["session_config"][0] + payload[id] = info + return jsonify(payload) + + def set_compute_package(self, file, helper_type): + """Set the compute package in the statestore. + + :param file: The compute package to set. + :type file: file + :return: A json response with success or failure message. + :rtype: :class:`flask.Response` + """ + + if file and self._allowed_file_extension(file.filename): + filename = secure_filename(file.filename) + # TODO: make configurable, perhaps in config.py or package.py + file_path = os.path.join("/app/client/package/", filename) + file.save(file_path) + + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): + return ( + jsonify( + { + "success": False, + "message": "Reducer is in instructing or monitoring state." + "Cannot set compute package.", + } + ), + 400, + ) + + self.control.set_compute_package(filename, file_path) + self.statestore.set_helper(helper_type) + + success = self.statestore.set_compute_package(filename) + if not success: + return ( + jsonify( + { + "success": False, + "message": "Failed to set compute package.", + } + ), + 400, + ) + return jsonify({"success": True, "message": "Compute package set."}) + + def _get_compute_package_name(self): + """Get the compute package name from the statestore. + + :return: The compute package name. + :rtype: str + """ + package_objects = self.statestore.get_compute_package() + if package_objects is None: + message = "No compute package found." + return None, message + else: + try: + name = package_objects["filename"] + except KeyError as e: + message = "No compute package found. Key error." + print(e) + return None, message + return name, "success" + + def get_compute_package(self): + """Get the compute package from the statestore. + + :return: The compute package as a json response. + :rtype: :class:`flask.Response` + """ + package_object = self.statestore.get_compute_package() + if package_object is None: + return ( + jsonify( + {"success": False, "message": "No compute package found."} + ), + 404, + ) + payload = {} + id = str(package_object["_id"]) + info = { + "filename": package_object["filename"], + "helper": package_object["helper"], + } + payload[id] = info + return jsonify(payload) + + def download_compute_package(self, name): + """Download the compute package. + + :return: The compute package as a json object. + :rtype: :class:`flask.Response` + """ + if name is None: + name, message = self._get_compute_package_name() + if name is None: + return jsonify({"success": False, "message": message}), 404 + try: + mutex = threading.Lock() + mutex.acquire() + # TODO: make configurable, perhaps in config.py or package.py + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) + except Exception: + try: + data = self.control.get_compute_package(name) + # TODO: make configurable, perhaps in config.py or package.py + file_path = os.path.join("/app/client/package/", name) + with open(file_path, "wb") as fh: + fh.write(data) + # TODO: make configurable, perhaps in config.py or package.py + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) + except Exception: + raise + finally: + mutex.release() + + def _create_checksum(self, name=None): + """Create the checksum of the compute package. + + :param name: The name of the compute package. + :type name: str + :return: Success or failure boolean, message and the checksum. + :rtype: bool, str, str + """ + + if name is None: + name, message = self._get_compute_package_name() + if name is None: + return False, message, "" + file_path = os.path.join( + "/app/client/package/", name + ) # TODO: make configurable, perhaps in config.py or package.py + try: + sum = str(sha(file_path)) + except FileNotFoundError: + sum = "" + message = "File not found." + return True, message, sum + + def get_checksum(self, name): + """Get the checksum of the compute package. + + :param name: The name of the compute package. + :type name: str + :return: The checksum as a json object. + :rtype: :py:class:`flask.Response` + """ + + success, message, sum = self._create_checksum(name) + if not success: + return jsonify({"success": False, "message": message}), 404 + payload = {"checksum": sum} + + return jsonify(payload) + + def get_controller_status(self): + """Get the status of the controller. + + :return: The status of the controller as a json object. + :rtype: :py:class:`flask.Response` + """ + return jsonify({"state": ReducerStateToString(self.control.state())}) + + def get_events(self, **kwargs): + """Get the events of the federated network. + + :return: The events as a json object. + :rtype: :py:class:`flask.Response` + """ + response = self.statestore.get_events(**kwargs) + + result = response["result"] + if result is None: + return ( + jsonify({"success": False, "message": "No events found."}), + 404, + ) + + events = [] + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) + + def get_all_validations(self, **kwargs): + """Get all validations from the statestore. + + :return: All validations as a json response. + :rtype: :class:`flask.Response` + """ + validations_objects = self.statestore.get_validations(**kwargs) + if validations_objects is None: + return ( + jsonify( + { + "success": False, + "message": "No validations found.", + "filter_used": kwargs, + } + ), + 404, + ) + payload = {} + for object in validations_objects: + id = str(object["_id"]) + info = { + "model_id": object["modelId"], + "data": object["data"], + "timestamp": object["timestamp"], + "meta": object["meta"], + "sender": object["sender"], + "receiver": object["receiver"], + } + payload[id] = info + return jsonify(payload) + + def add_combiner( + self, combiner_id, secure_grpc, address, remote_addr, fqdn, port + ): + """Add a combiner to the network. + + :param combiner_id: The combiner id to add. + :type combiner_id: str + :param secure_grpc: Whether to use secure grpc or not. + :type secure_grpc: bool + :param name: The name of the combiner. + :type name: str + :param address: The address of the combiner. + :type address: str + :param remote_addr: The remote address of the combiner. + :type remote_addr: str + :param fqdn: The fqdn of the combiner. + :type fqdn: str + :param port: The port of the combiner. + :type port: int + :return: Config of the combiner as a json response. + :rtype: :class:`flask.Response` + """ + # TODO: Any more required check for config? Formerly based on status: "retry" + if not self.control.idle(): + return jsonify( + { + "success": False, + "status": "retry", + "message": "Conroller is not in idle state, try again later. ", + } + ) + # Check if combiner already exists + combiner = self.control.network.get_combiner(combiner_id) + if not combiner: + if secure_grpc == "True": + certificate, key = self.certificate_manager.get_or_create( + address + ).get_keypair_raw() + _ = base64.b64encode(certificate) + _ = base64.b64encode(key) + + else: + certificate = None + key = None + + combiner_interface = CombinerInterface( + parent=self._to_dict(), + name=combiner_id, + address=address, + fqdn=fqdn, + port=port, + certificate=copy.deepcopy(certificate), + key=copy.deepcopy(key), + ip=remote_addr, + ) + + self.control.network.add_combiner(combiner_interface) + + # Check combiner now exists + combiner = self.control.network.get_combiner(combiner_id) + if not combiner: + return jsonify( + {"success": False, "message": "Combiner not added."} + ) + + payload = { + "success": True, + "message": "Combiner added successfully.", + "status": "added", + "storage": self.statestore.get_storage_backend(), + "statestore": self.statestore.get_config(), + "certificate": combiner.get_certificate(), + "key": combiner.get_key(), + } + + return jsonify(payload) + + def add_client(self, client_id, preferred_combiner, remote_addr): + """Add a client to the network. + + :param client_id: The client id to add. + :type client_id: str + :param preferred_combiner: The preferred combiner for the client.If None, the combiner will be chosen based on availability. + :type preferred_combiner: str + :return: A json response with combiner assignment config. + :rtype: :class:`flask.Response` + """ + # Check if package has been set + package_object = self.statestore.get_compute_package() + if package_object is None: + return ( + jsonify( + { + "success": False, + "status": "retry", + "message": "No compute package found. Set package in controller.", + } + ), + 203, + ) + + # Assign client to combiner + if preferred_combiner: + combiner = self.control.network.get_combiner(preferred_combiner) + if combiner is None: + return ( + jsonify( + { + "success": False, + "message": f"Combiner {preferred_combiner} not found or unavailable.", + } + ), + 400, + ) + else: + combiner = self.control.network.find_available_combiner() + if combiner is None: + return ( + jsonify( + {"success": False, "message": "No combiner available."} + ), + 400, + ) + + client_config = { + "name": client_id, + "combiner_preferred": preferred_combiner, + "combiner": combiner.name, + "ip": remote_addr, + "status": "available", + } + # Add client to network + self.control.network.add_client(client_config) + + # Setup response containing information about the combiner for assinging the client + if combiner.certificate: + cert_b64 = base64.b64encode(combiner.certificate) + cert = str(cert_b64).split("'")[1] + else: + cert = None + + payload = { + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": "remote", # TODO: Make this configurable + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "helper_type": self.control.statestore.get_helper(), + } + print("Seding payload: ", payload, flush=True) + + return jsonify(payload) + + def get_initial_model(self): + """Get the initial model from the statestore. + + :return: The initial model as a json response. + :rtype: :class:`flask.Response` + """ + model_id = self.statestore.get_initial_model() + payload = {"model_id": model_id} + return jsonify(payload) + + def set_initial_model(self, file): + """Add an initial model to the network. + + :param file: The initial model to add. + :type file: file + :return: A json response with success or failure message. + :rtype: :class:`flask.Response` + """ + try: + object = BytesIO() + object.seek(0, 0) + file.seek(0) + object.write(file.read()) + helper = self.control.get_helper() + object.seek(0) + model = helper.load(object) + self.control.commit(file.filename, model) + except Exception as e: + print(e, flush=True) + return jsonify({"success": False, "message": e}) + + return jsonify( + {"success": True, "message": "Initial model added successfully."} + ) + + def get_latest_model(self): + """Get the latest model from the statestore. + + :return: The initial model as a json response. + :rtype: :class:`flask.Response` + """ + if self.statestore.get_latest_model(): + model_id = self.statestore.get_latest_model() + payload = {"model_id": model_id} + return jsonify(payload) + else: + return jsonify( + {"success": False, "message": "No initial model set."} + ) + + def get_models(self, session_id=None, limit=None, skip=None): + result = self.statestore.list_models(session_id, limit, skip) + + if result is None: + return ( + jsonify({"success": False, "message": "No models found."}), + 404, + ) + + arr = [] + + for model in result["result"]: + arr.append(model) + + result = {"result": arr, "count": result["count"]} + + return jsonify(result) + + def get_model_trail(self): + """Get the model trail for a given session. + + :param session: The session id to get the model trail for. + :type session: str + :return: The model trail for the given session as a json response. + :rtype: :class:`flask.Response` + """ + model_info = self.statestore.get_model_trail() + if model_info: + return jsonify(model_info) + else: + return jsonify( + {"success": False, "message": "No model trail available."} + ) + + def get_all_rounds(self): + """Get all rounds. + + :return: The rounds as json response. + :rtype: :class:`flask.Response` + """ + rounds_objects = self.statestore.get_rounds() + if rounds_objects is None: + jsonify({"success": False, "message": "No rounds available."}) + payload = {} + for object in rounds_objects: + id = object["round_id"] + if "reducer" in object.keys(): + reducer = object["reducer"] + else: + reducer = None + if "combiners" in object.keys(): + combiners = object["combiners"] + else: + combiners = None + + info = { + "reducer": reducer, + "combiners": combiners, + } + payload[id] = info + else: + return jsonify(payload) + + def get_round(self, round_id): + """Get a round. + + :param round_id: The round id to get. + :type round_id: str + :return: The round as json response. + :rtype: :class:`flask.Response` + """ + round_object = self.statestore.get_round(round_id) + if round_object is None: + return jsonify({"success": False, "message": "Round not found."}) + payload = { + "round_id": round_object["round_id"], + "reducer": round_object["reducer"], + "combiners": round_object["combiners"], + } + return jsonify(payload) + + def get_client_config(self, checksum=True): + """Get the client config. + + :return: The client config as json response. + :rtype: :py:class:`flask.Response` + """ + config = get_controller_config() + network_id = get_network_config() + port = config["port"] + host = config["host"] + payload = { + "network_id": network_id, + "discover_host": host, + "discover_port": port, + } + if checksum: + success, _, checksum_str = self._create_checksum() + if success: + payload["checksum"] = checksum_str + return jsonify(payload) + + def get_plot_data(self, feature=None): + """Get plot data. + + :return: The plot data as json response. + :rtype: :py:class:`flask.Response` + """ + + plot = Plot(self.control.statestore) + + try: + valid_metrics = plot.fetch_valid_metrics() + feature = feature or valid_metrics[0] + box_plot = plot.create_box_plot(feature) + except Exception as e: + valid_metrics = None + box_plot = None + print(e, flush=True) + + result = { + "valid_metrics": valid_metrics, + "box_plot": box_plot, + } + + return jsonify(result) + + def list_combiners_data(self, combiners): + """Get combiners data. + + :param combiners: The combiners to get data for. + :type combiners: list + :return: The combiners data as json response. + :rtype: :py:class:`flask.Response` + """ + + response = self.statestore.list_combiners_data(combiners) + + arr = [] + + # order list by combiner name + for element in response: + + obj = { + "combiner": element["_id"], + "count": element["count"], + } + + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + + def start_session( + self, + session_id, + rounds=5, + round_timeout=180, + round_buffer_size=-1, + delete_models=False, + validate=True, + helper="keras", + min_clients=1, + requested_clients=8, + ): + """Start a session. + + :param session_id: The session id to start. + :type session_id: str + :param rounds: The number of rounds to perform. + :type rounds: int + :param round_timeout: The round timeout to use in seconds. + :type round_timeout: int + :param round_buffer_size: The round buffer size to use. + :type round_buffer_size: int + :param delete_models: Whether to delete models after each round at combiner (save storage). + :type delete_models: bool + :param validate: Whether to validate the model after each round. + :type validate: bool + :param min_clients: The minimum number of clients required. + :type min_clients: int + :param requested_clients: The requested number of clients. + :type requested_clients: int + :return: A json response with success or failure message and session config. + :rtype: :class:`flask.Response` + """ + # Check if session already exists + session = self.statestore.get_session(session_id) + if session: + return jsonify( + {"success": False, "message": "Session already exists."} + ) + + # Check if session is running + if self.control.state() == ReducerState.monitoring: + return jsonify( + {"success": False, "message": "A session is already running."} + ) + + # Check available clients per combiner + clients_available = 0 + for combiner in self.control.network.get_combiners(): + try: + combiner_state = combiner.report() + nr_active_clients = combiner_state["nr_active_clients"] + clients_available = clients_available + int(nr_active_clients) + except CombinerUnavailableError as e: + # TODO: Handle unavailable combiner, stop session or continue? + print("COMBINER UNAVAILABLE: {}".format(e), flush=True) + continue + + if clients_available < min_clients: + return jsonify( + { + "success": False, + "message": "Not enough clients available to start session.", + } + ) + + # Check if validate is string and convert to bool + if isinstance(validate, str): + if validate.lower() == "true": + validate = True + else: + validate = False + + # Get lastest model as initial model for session + model_id = self.statestore.get_latest_model() + + # Setup session config + session_config = { + "session_id": session_id, + "round_timeout": round_timeout, + "buffer_size": round_buffer_size, + "model_id": model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": min_clients, + "clients_requested": requested_clients, + "task": (""), + "validate": validate, + "helper_type": helper, + } + + # Start session + threading.Thread( + target=self.control.session, args=(session_config,) + ).start() + + # Return success response + return jsonify( + { + "success": True, + "message": "Session started successfully.", + "config": session_config, + } + ) diff --git a/fedn/fedn/network/network.py b/fedn/fedn/network/api/network.py similarity index 59% rename from fedn/fedn/network/network.py rename to fedn/fedn/network/api/network.py index 15045db82..6fcaad053 100644 --- a/fedn/fedn/network/network.py +++ b/fedn/fedn/network/api/network.py @@ -4,9 +4,14 @@ CombinerUnavailableError) from fedn.network.loadbalancer.leastpacked import LeastPacked +__all__ = 'Network', + class Network: - """ FEDn network. """ + """ FEDn network interface. This class is used to interact with the network. + Note: This class contain redundant code, which is not used in the current version of FEDn. + Some methods has been moved to :class:`fedn.network.api.interface.API`. + """ def __init__(self, control, statestore, load_balancer=None): """ """ @@ -20,10 +25,12 @@ def __init__(self, control, statestore, load_balancer=None): self.load_balancer = load_balancer def get_combiner(self, name): - """ + """ Get combiner by name. - :param name: - :return: + :param name: name of combiner + :type name: str + :return: The combiner instance object + :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` """ combiners = self.get_combiners() for combiner in combiners: @@ -32,13 +39,14 @@ def get_combiner(self, name): return None def get_combiners(self): - """ + """ Get all combiners in the network. - :return: + :return: list of combiners objects + :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) """ data = self.statestore.get_combiners() combiners = [] - for c in data: + for c in data["result"]: if c['certificate']: cert = base64.b64decode(c['certificate']) key = base64.b64decode(c['key']) @@ -53,10 +61,11 @@ def get_combiners(self): return combiners def add_combiner(self, combiner): - """ + """ Add a new combiner to the network. - :param combiner: - :return: + :param combiner: The combiner instance object + :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` + :return: None """ if not self.control.idle(): print("Reducer is not idle, cannot add additional combiner.") @@ -69,10 +78,11 @@ def add_combiner(self, combiner): self.statestore.set_combiner(combiner.to_dict()) def remove_combiner(self, combiner): - """ + """ Remove a combiner from the network. - :param combiner: - :return: + :param combiner: The combiner instance object + :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` + :return: None """ if not self.control.idle(): print("Reducer is not idle, cannot remove combiner.") @@ -80,15 +90,21 @@ def remove_combiner(self, combiner): self.statestore.delete_combiner(combiner.name) def find_available_combiner(self): - """ + """ Find an available combiner in the network. - :return: + :return: The combiner instance object + :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` """ combiner = self.load_balancer.find_combiner() return combiner def handle_unavailable_combiner(self, combiner): - """ This callback is triggered if a combiner is found to be unresponsive. """ + """ This callback is triggered if a combiner is found to be unresponsive. + + :param combiner: The combiner instance object + :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` + :return: None + """ # TODO: Implement strategy to handle an unavailable combiner. print("REDUCER CONTROL: Combiner {} unavailable.".format( combiner.name), flush=True) @@ -96,8 +112,9 @@ def handle_unavailable_combiner(self, combiner): def add_client(self, client): """ Add a new client to the network. - :param client: - :return: + :param client: The client instance object + :type client: dict + :return: None """ if self.get_client(client['name']): @@ -107,24 +124,43 @@ def add_client(self, client): self.statestore.set_client(client) def get_client(self, name): - """ + """ Get client by name. - :param name: - :return: + :param name: name of client + :type name: str + :return: The client instance object + :rtype: ObjectId """ ret = self.statestore.get_client(name) return ret def update_client_data(self, client_data, status, role): - """ Update client status on DB""" + """ Update client status in statestore. + + :param client_data: The client instance object + :type client_data: dict + :param status: The client status + :type status: str + :param role: The client role + :type role: str + :return: None + """ self.statestore.update_client_status(client_data, status, role) def get_client_info(self): - """ list available client in DB""" + """ list available client in statestore. + + :return: list of client objects + :rtype: list(ObjectId) + """ return self.statestore.list_clients() def describe(self): - """ """ + """ Describe the network. + + :return: The network description + :rtype: dict + """ network = [] for combiner in self.get_combiners(): try: @@ -133,7 +169,3 @@ def describe(self): # TODO, do better here. pass return network - - def check_health(self): - """ """ - pass diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py new file mode 100644 index 000000000..cfb91bece --- /dev/null +++ b/fedn/fedn/network/api/server.py @@ -0,0 +1,388 @@ +from flask import Flask, jsonify, request + +from fedn.common.config import (get_controller_config, get_modelstorage_config, + get_network_config, get_statestore_config) +from fedn.network.api.interface import API +from fedn.network.controller.control import Control +from fedn.network.statestore.mongostatestore import MongoStateStore + +statestore_config = get_statestore_config() +network_id = get_network_config() +modelstorage_config = get_modelstorage_config() +statestore = MongoStateStore( + network_id, statestore_config["mongo_config"], modelstorage_config +) +control = Control(statestore=statestore) +api = API(statestore, control) +app = Flask(__name__) + + +@app.route("/get_model_trail", methods=["GET"]) +def get_model_trail(): + """Get the model trail for a given session. + param: session: The session id to get the model trail for. + type: session: str + return: The model trail for the given session as a json object. + rtype: json + """ + return api.get_model_trail() + + +@app.route("/list_models", methods=["GET"]) +def list_models(): + """Get models from the statestore. + param: + session_id: The session id to get the model trail for. + limit: The maximum number of models to return. + type: limit: int + param: skip: The number of models to skip. + type: skip: int + Returns: + _type_: json + """ + + session_id = request.args.get("session_id", None) + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_models(session_id, limit, skip) + + +@app.route("/delete_model_trail", methods=["GET", "POST"]) +def delete_model_trail(): + """Delete the model trail for a given session. + param: session: The session id to delete the model trail for. + type: session: str + return: The response from the statestore. + rtype: json + """ + return jsonify({"message": "Not implemented"}), 501 + + +@app.route("/list_clients", methods=["GET"]) +def list_clients(): + """Get all clients from the statestore. + return: All clients as a json object. + rtype: json + """ + + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + status = request.args.get("status", None) + + return api.get_clients(limit, skip, status) + + +@app.route("/get_active_clients", methods=["GET"]) +def get_active_clients(): + """Get all active clients from the statestore. + param: combiner_id: The combiner id to get active clients for. + type: combiner_id: str + return: All active clients as a json object. + rtype: json + """ + combiner_id = request.args.get("combiner", None) + if combiner_id is None: + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) + return api.get_active_clients(combiner_id) + + +@app.route("/list_combiners", methods=["GET"]) +def list_combiners(): + """Get all combiners in the network. + return: All combiners as a json object. + rtype: json + """ + + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_combiners(limit, skip) + + +@app.route("/get_combiner", methods=["GET"]) +def get_combiner(): + """Get a combiner from the statestore. + param: combiner_id: The combiner id to get. + type: combiner_id: str + return: The combiner as a json object. + rtype: json + """ + combiner_id = request.args.get("combiner", None) + if combiner_id is None: + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) + return api.get_combiner(combiner_id) + + +@app.route("/list_rounds", methods=["GET"]) +def list_rounds(): + """Get all rounds from the statestore. + return: All rounds as a json object. + rtype: json + """ + return api.get_all_rounds() + + +@app.route("/get_round", methods=["GET"]) +def get_round(): + """Get a round from the statestore. + param: round_id: The round id to get. + type: round_id: str + return: The round as a json object. + rtype: json + """ + round_id = request.args.get("round_id", None) + if round_id is None: + return jsonify({"success": False, "message": "Missing round id."}), 400 + return api.get_round(round_id) + + +@app.route("/start_session", methods=["GET", "POST"]) +def start_session(): + """Start a new session. + return: The response from control. + rtype: json + """ + json_data = request.get_json() + return api.start_session(**json_data) + + +@app.route("/list_sessions", methods=["GET"]) +def list_sessions(): + """Get all sessions from the statestore. + return: All sessions as a json object. + rtype: json + """ + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_sessions(limit, skip) + + +@app.route("/get_session", methods=["GET"]) +def get_session(): + """Get a session from the statestore. + param: session_id: The session id to get. + type: session_id: str + return: The session as a json object. + rtype: json + """ + session_id = request.args.get("session_id", None) + if session_id is None: + return ( + jsonify({"success": False, "message": "Missing session id."}), + 400, + ) + return api.get_session(session_id) + + +@app.route("/set_package", methods=["POST"]) +def set_package(): + """ Set the compute package in the statestore. + Usage with curl: + curl -k -X POST \ + -F file=@package.tgz \ + -F helper="kerashelper" \ + http://localhost:8092/set_package + + param: file: The compute package file to set. + type: file: file + return: The response from the statestore. + rtype: json + """ + helper_type = request.form.get("helper", None) + if helper_type is None: + return ( + jsonify({"success": False, "message": "Missing helper type."}), + 400, + ) + try: + file = request.files["file"] + except KeyError: + return jsonify({"success": False, "message": "Missing file."}), 400 + return api.set_compute_package(file=file, helper_type=helper_type) + + +@app.route("/get_package", methods=["GET"]) +def get_package(): + """Get the compute package from the statestore. + return: The compute package as a json object. + rtype: json + """ + return api.get_compute_package() + + +@app.route("/download_package", methods=["GET"]) +def download_package(): + """Download the compute package. + return: The compute package as a json object. + rtype: json + """ + name = request.args.get("name", None) + return api.download_compute_package(name) + + +@app.route("/get_package_checksum", methods=["GET"]) +def get_package_checksum(): + name = request.args.get("name", None) + return api.get_checksum(name) + + +@app.route("/get_latest_model", methods=["GET"]) +def get_latest_model(): + """Get the latest model from the statestore. + return: The initial model as a json object. + rtype: json + """ + return api.get_latest_model() + + +# Get initial model endpoint + + +@app.route("/get_initial_model", methods=["GET"]) +def get_initial_model(): + """Get the initial model from the statestore. + return: The initial model as a json object. + rtype: json + """ + return api.get_initial_model() + + +@app.route("/set_initial_model", methods=["POST"]) +def set_initial_model(): + """Set the initial model in the statestore and upload to model repository. + Usage with curl: + curl -k -X POST + -F file=@seed.npz + http://localhost:8092/set_initial_model + + param: file: The initial model file to set. + type: file: file + return: The response from the statestore. + rtype: json + """ + try: + file = request.files["file"] + except KeyError: + return jsonify({"success": False, "message": "Missing file."}), 400 + return api.set_initial_model(file) + + +@app.route("/get_controller_status", methods=["GET"]) +def get_controller_status(): + """Get the status of the controller. + return: The status as a json object. + rtype: json + """ + return api.get_controller_status() + + +@app.route("/get_client_config", methods=["GET"]) +def get_client_config(): + """Get the client configuration. + return: The client configuration as a json object. + rtype: json + """ + checksum = request.args.get("checksum", True) + return api.get_client_config(checksum) + + +@app.route("/get_events", methods=["GET"]) +def get_events(): + """Get the events from the statestore. + return: The events as a json object. + rtype: json + """ + # TODO: except filter with request.get_json() + kwargs = request.args.to_dict() + + return api.get_events(**kwargs) + + +@app.route("/list_validations", methods=["GET"]) +def list_validations(): + """Get all validations from the statestore. + return: All validations as a json object. + rtype: json + """ + # TODO: except filter with request.get_json() + kwargs = request.args.to_dict() + return api.get_all_validations(**kwargs) + + +@app.route("/add_combiner", methods=["POST"]) +def add_combiner(): + """Add a combiner to the network. + return: The response from the statestore. + rtype: json + """ + json_data = request.get_json() + remote_addr = request.remote_addr + try: + response = api.add_combiner(**json_data, remote_addr=remote_addr) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/add_client", methods=["POST"]) +def add_client(): + """Add a client to the network. + return: The response from control. + rtype: json + """ + + json_data = request.get_json() + remote_addr = request.remote_addr + try: + response = api.add_client(**json_data, remote_addr=remote_addr) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/list_combiners_data", methods=["POST"]) +def list_combiners_data(): + """List data from combiners. + return: The response from control. + rtype: json + """ + + json_data = request.get_json() + + # expects a list of combiner names (strings) in an array + combiners = json_data.get("combiners", None) + + try: + response = api.list_combiners_data(combiners) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/get_plot_data", methods=["GET"]) +def get_plot_data(): + """Get plot data from the statestore. + rtype: json + """ + + try: + feature = request.args.get("feature", None) + response = api.get_plot_data(feature=feature) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +if __name__ == "__main__": + config = get_controller_config() + port = config["port"] + debug = config["debug"] + app.run(debug=debug, port=port, host="0.0.0.0") diff --git a/fedn/fedn/network/api/tests.py b/fedn/fedn/network/api/tests.py new file mode 100644 index 000000000..7395d9bdf --- /dev/null +++ b/fedn/fedn/network/api/tests.py @@ -0,0 +1,335 @@ +# Unittest for Flask API endpoints +# +# Run with: +# python -m unittest fedn.tests.network.api.tests +# +# or +# +# python3 -m unittest fedn.tests.network.api.tests +# +# or +# +# python3 -m unittest fedn.tests.network.api.tests.NetworkAPITests +# +# or +# +# python -m unittest fedn.tests.network.api.tests.NetworkAPITests +# +# or +# +# python -m unittest fedn.tests.network.api.tests.NetworkAPITests.test_get_model_trail +# +# or +# +# python3 -m unittest fedn.tests.network.api.tests.NetworkAPITests.test_get_model_trail +# + +import io +import time +import unittest +from unittest.mock import MagicMock, patch + +import fedn + + +class NetworkAPITests(unittest.TestCase): + """ Unittests for the Network API. """ + @patch('fedn.network.statestore.mongostatestore.MongoStateStore', autospec=True) + @patch('fedn.network.controller.controlbase.ControlBase', autospec=True) + def setUp(self, mock_mongo, mock_control): + # start Flask server in testing mode + import fedn.network.api.server + self.app = fedn.network.api.server.app.test_client() + + def test_get_model_trail(self): + """ Test get_model_trail endpoint. """ + # Mock api.get_model_trail + model_id = "test" + time_stamp = time.time() + return_value = {model_id: time_stamp} + fedn.network.api.server.api.get_model_trail = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_model_trail') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_model_trail was called + fedn.network.api.server.api.get_model_trail.assert_called_once_with() + + def test_get_latest_model(self): + """ Test get_latest_model endpoint. """ + # Mock api.get_latest_model + model_id = "test" + time_stamp = time.time() + return_value = {model_id: time_stamp} + fedn.network.api.server.api.get_latest_model = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_latest_model') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_latest_model was called + fedn.network.api.server.api.get_latest_model.assert_called_once_with() + + def test_get_initial_model(self): + """ Test get_initial_model endpoint. """ + # Mock api.get_initial_model + model_id = "test" + time_stamp = time.time() + return_value = {model_id: time_stamp} + fedn.network.api.server.api.get_initial_model = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_initial_model') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_initial_model was called + fedn.network.api.server.api.get_initial_model.assert_called_once_with() + + def test_set_initial_model(self): + """ Test set_initial_model endpoint. """ + # Mock api.set_initial_model + success = True + message = "test" + return_value = {'success': success, 'message': message} + fedn.network.api.server.api.set_initial_model = MagicMock(return_value=return_value) + # Create test file + request_file = (io.BytesIO(b"abcdef"), 'test.txt') + # Make request + response = self.app.post('/set_initial_model', data={"file": request_file}) + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.set_initial_model was called + fedn.network.api.server.api.set_initial_model.assert_called_once() + + def test_list_clients(self): + """ Test list_clients endpoint. """ + # Mock api.get_all_clients + return_value = {"test": "test"} + fedn.network.api.server.api.get_all_clients = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/list_clients') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_all_clients was called + fedn.network.api.server.api.get_all_clients.assert_called_once_with() + + def test_get_active_clients(self): + """ Test get_active_clients endpoint. """ + # Mock api.get_active_clients + return_value = {"test": "test"} + fedn.network.api.server.api.get_active_clients = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_active_clients?combiner=test') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_active_clients was called + fedn.network.api.server.api.get_active_clients.assert_called_once_with("test") + + def test_add_client(self): + """ Test add_client endpoint. """ + # Mock api.add_client + return_value = {"test": "test"} + fedn.network.api.server.api.add_client = MagicMock(return_value=return_value) + # Make request + response = self.app.post('/add_client', json={ + 'preferred_combiner': 'test', + }) + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.add_client was called + fedn.network.api.server.api.add_client.assert_called_once_with( + preferred_combiner="test", + remote_addr='127.0.0.1' + ) + + def test_list_combiners(self): + """ Test list_combiners endpoint. """ + # Mock api.get_all_combiners + return_value = {"test": "test"} + fedn.network.api.server.api.get_all_combiners = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/list_combiners') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_all_combiners was called + fedn.network.api.server.api.get_all_combiners.assert_called_once_with() + + def test_list_rounds(self): + """ Test list_rounds endpoint. """ + # Mock api.get_all_rounds + return_value = {"test": "test"} + fedn.network.api.server.api.get_all_rounds = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/list_rounds') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_all_rounds was called + fedn.network.api.server.api.get_all_rounds.assert_called_once_with() + + def test_get_round(self): + """ Test get_round endpoint. """ + # Mock api.get_round + return_value = {"test": "test"} + fedn.network.api.server.api.get_round = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_round?round_id=test') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_round was called + fedn.network.api.server.api.get_round.assert_called_once_with("test") + + def test_get_combiner(self): + """ Test get_combiner endpoint. """ + # Mock api.get_combiner + return_value = {"test": "test"} + fedn.network.api.server.api.get_combiner = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_combiner?combiner=test') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_combiner was called + fedn.network.api.server.api.get_combiner.assert_called_once_with("test") + + def test_add_combiner(self): + """ Test add_combiner endpoint. """ + # Mock api.add_combiner + success = True + message = "test" + return_value = {'success': success, 'message': message} + fedn.network.api.server.api.add_combiner = MagicMock(return_value=return_value) + # Make request + response = self.app.post('/add_combiner', json={ + 'combiner_id': 'test', + 'address': '1234', + 'port': '1234', + 'secure_grpc': 'True', + 'fqdn': 'test', + }) + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.add_combiner was called + fedn.network.api.server.api.add_combiner.assert_called_once_with( + combiner_id='test', + remote_addr='127.0.0.1', + address='1234', + port='1234', + secure_grpc='True', + fqdn='test', + ) + + def test_get_events(self): + """ Test get_events endpoint. """ + # Mock api.get_events + return_value = {"test": "test"} + fedn.network.api.server.api.get_events = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_events') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_events was called + fedn.network.api.server.api.get_events.assert_called_once() + + def test_get_status(self): + """ Test get_status endpoint. """ + # Mock api.get_status + return_value = {"test": "test"} + fedn.network.api.server.api.get_controller_status = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_controller_status') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_status was called + fedn.network.api.server.api.get_controller_status.assert_called_once() + + def test_start_session(self): + """ Test start_session endpoint. """ + # Mock api.start_session + success = True + message = "test" + return_value = {'success': success, 'message': message} + fedn.network.api.server.api.start_session = MagicMock(return_value=return_value) + # Make request with only session_id + json = {'session_id': 'test', + 'round_timeout': float(60), + 'rounds': 1, + 'round_buffer_size': -1, + } + response = self.app.post('/start_session', json=json) + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.start_session was called + fedn.network.api.server.api.start_session.assert_called_once_with( + session_id='test', + round_timeout=float(60), + rounds=1, + round_buffer_size=-1, + ) + + def test_list_sessions(self): + """ Test list_sessions endpoint. """ + # Mock api.list_sessions + return_value = {"test": "test"} + fedn.network.api.server.api.get_all_sessions = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/list_sessions') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.list_sessions was called + fedn.network.api.server.api.get_all_sessions.assert_called_once() + + def test_get_package(self): + """ Test get_package endpoint. """ + # Mock api.get_package + return_value = {"test": "test"} + fedn.network.api.server.api.get_compute_package = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_package') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_package was called + fedn.network.api.server.api.get_compute_package.assert_called_once_with() + + def test_get_controller_status(self): + """ Test get_controller_status endpoint. """ + # Mock api.get_controller_status + return_value = {"test": "test"} + fedn.network.api.server.api.get_controller_status = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_controller_status') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_controller_status was called + fedn.network.api.server.api.get_controller_status.assert_called_once_with() + + def test_get_client_config(self): + """ Test get_client_config endpoint. """ + # Mock api.get_client_config + return_value = {"test": "test"} + fedn.network.api.server.api.get_client_config = MagicMock(return_value=return_value) + # Make request + response = self.app.get('/get_client_config') + # Assert response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, return_value) + # Assert api.get_client_config was called + fedn.network.api.server.api.get_client_config.assert_called_once_with(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/fedn/fedn/network/clients/__init__.py b/fedn/fedn/network/clients/__init__.py index e69de29bb..effcee624 100644 --- a/fedn/fedn/network/clients/__init__.py +++ b/fedn/fedn/network/clients/__init__.py @@ -0,0 +1,4 @@ +""" The FEDn client package is responsible for executing the federated learning tasks, including ML model training and validation. It's the acting gRPC client for the federated network. +The client first connacts the centralized controller to receive :class:`fedn.network.combiner.Combiner` assingment. The client then connects to the combiner and +sends requests to the combiner to receive model updates and send model updates.""" +# flake8: noqa diff --git a/fedn/fedn/network/clients/client.py b/fedn/fedn/network/clients/client.py index b2d358112..9851b32ef 100644 --- a/fedn/fedn/network/clients/client.py +++ b/fedn/fedn/network/clients/client.py @@ -15,14 +15,12 @@ from io import BytesIO import grpc -from flask import Flask from google.protobuf.json_format import MessageToJson import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc -from fedn.common.control.package import PackageRuntime -from fedn.common.net.connect import ConnectorClient, Status -from fedn.common.net.web.client import page, style +from fedn.network.clients.connect import ConnectorClient, Status +from fedn.network.clients.package import PackageRuntime from fedn.network.clients.state import ClientState, ClientStateToString from fedn.utils.dispatcher import Dispatcher from fedn.utils.helpers import get_helper @@ -42,25 +40,15 @@ def __call__(self, context, callback): class Client: """FEDn Client. Service running on client/datanodes in a federation, - recieving and handling model update and model validation requests. - - Attibutes - --------- - config: dict - A configuration dictionary containing connection information for - the discovery service (controller) and settings governing e.g. - client-combiner assignment behavior. + recieving and handling model update and model validation requests. + :param config: A configuration dictionary containing connection information for the discovery service (controller) + and settings governing e.g. client-combiner assignment behavior. + :type config: dict """ def __init__(self, config): - """Initialize the client. - - :param config: A configuration dictionary containing connection information for - the discovery service (controller) and settings governing e.g. - client-combiner assignment behavior. - :type config: dict - """ + """Initialize the client.""" self.state = None self.error_state = False @@ -109,6 +97,106 @@ def __init__(self, config): self.state = ClientState.idle + def _assign(self): + """Contacts the controller and asks for combiner assignment. + + :return: A configuration dictionary containing connection information for combiner. + :rtype: dict + """ + + print("Asking for assignment!", flush=True) + while True: + status, response = self.connector.assign() + if status == Status.TryAgain: + print(response, flush=True) + time.sleep(5) + continue + if status == Status.Assigned: + client_config = response + break + if status == Status.UnAuthorized: + print(response, flush=True) + sys.exit("Exiting: Unauthorized") + if status == Status.UnMatchedConfig: + print(response, flush=True) + sys.exit("Exiting: UnMatchedConfig") + time.sleep(5) + print(".", end=' ', flush=True) + + print("Got assigned!", flush=True) + print("Received combiner config: {}".format(client_config), flush=True) + return client_config + + def _connect(self, client_config): + """Connect to assigned combiner. + + :param client_config: A configuration dictionary containing connection information for + the combiner. + :type client_config: dict + """ + + # TODO use the client_config['certificate'] for setting up secure comms' + host = client_config['host'] + port = client_config['port'] + secure = False + if client_config['fqdn'] is not None: + host = client_config['fqdn'] + # assuming https if fqdn is used + port = 443 + print(f"CLIENT: Connecting to combiner host: {host}:{port}", flush=True) + + if client_config['certificate']: + print("CLIENT: using certificate from Reducer for GRPC channel") + secure = True + cert = base64.b64decode( + client_config['certificate']) # .decode('utf-8') + credentials = grpc.ssl_channel_credentials(root_certificates=cert) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): + secure = True + print("CLIENT: using root certificate from environment variable for GRPC channel") + with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f: + credentials = grpc.ssl_channel_credentials(f.read()) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + elif self.config['secure']: + secure = True + print("CLIENT: using CA certificate for GRPC channel") + cert = ssl.get_server_certificate((host, port)) + + credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) + if self.config['token']: + token = self.config['token'] + auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) + channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) + else: + channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) + else: + print("CLIENT: using insecure GRPC channel") + if port == 443: + port = 80 + channel = grpc.insecure_channel("{}:{}".format( + host, + str(port))) + + self.channel = channel + + self.connectorStub = rpc.ConnectorStub(channel) + self.combinerStub = rpc.CombinerStub(channel) + self.modelStub = rpc.ModelServiceStub(channel) + + print("Client: {} connected {} to {}:{}".format(self.name, + "SECURED" if secure else "INSECURE", + host, + port), + flush=True) + + print("Client: Using {} compute package.".format( + client_config["package"])) + + def _disconnect(self): + """Disconnect from the combiner.""" + self.channel.close() + def _detach(self): """Detach from the FEDn network (disconnect from combiner)""" # Setting _attached to False will make all processing threads return @@ -137,21 +225,21 @@ def _initialize_helper(self, client_config): """Initialize the helper class for the client. :param client_config: A configuration dictionary containing connection information for - the discovery service (controller) and settings governing e.g. - client-combiner assignment behavior. + | the discovery service (controller) and settings governing e.g. + | client-combiner assignment behavior. :type client_config: dict + :return: """ - if 'model_type' in client_config.keys(): - self.helper = get_helper(client_config['model_type']) + if 'helper_type' in client_config.keys(): + self.helper = get_helper(client_config['helper_type']) def _subscribe_to_combiner(self, config): """Listen to combiner message stream and start all processing threads. :param config: A configuration dictionary containing connection information for - the discovery service (controller) and settings governing e.g. - client-combiner assignment behavior. - + | the discovery service (controller) and settings governing e.g. + | client-combiner assignment behavior. """ # Start sending heartbeats to the combiner. @@ -174,10 +262,10 @@ def _initialize_dispatcher(self, config): """ Initialize the dispatcher for the client. :param config: A configuration dictionary containing connection information for - the discovery service (controller) and settings governing e.g. - client-combiner assignment behavior. + | the discovery service (controller) and settings governing e.g. + | client-combiner assignment behavior. :type config: dict - + :return: """ if config['remote_compute_context']: pr = PackageRuntime(os.getcwd(), os.getcwd()) @@ -232,117 +320,18 @@ def _initialize_dispatcher(self, config): copy_tree(from_path, self.run_path) self.dispatcher = Dispatcher(dispatch_config, self.run_path) - def _assign(self): - """Contacts the controller and asks for combiner assignment. - - :return: A configuration dictionary containing connection information for combiner. - :rtype: dict - """ - - print("Asking for assignment!", flush=True) - while True: - status, response = self.connector.assign() - if status == Status.TryAgain: - print(response, flush=True) - time.sleep(5) - continue - if status == Status.Assigned: - client_config = response - break - if status == Status.UnAuthorized: - print(response, flush=True) - sys.exit("Exiting: Unauthorized") - if status == Status.UnMatchedConfig: - print(response, flush=True) - sys.exit("Exiting: UnMatchedConfig") - time.sleep(5) - print(".", end=' ', flush=True) - - print("Got assigned!", flush=True) - return client_config - - def _connect(self, client_config): - """Connect to assigned combiner. - - :param client_config: A configuration dictionary containing connection information for - the combiner. - :type client_config: dict - """ - - # TODO use the client_config['certificate'] for setting up secure comms' - host = client_config['host'] - port = client_config['port'] - secure = False - if client_config['fqdn'] != "None": - host = client_config['fqdn'] - # assuming https if fqdn is used - port = 443 - print(f"CLIENT: Connecting to combiner host: {host}:{port}", flush=True) - - if client_config['certificate']: - print("CLIENT: using certificate from Reducer for GRPC channel") - secure = True - cert = base64.b64decode( - client_config['certificate']) # .decode('utf-8') - credentials = grpc.ssl_channel_credentials(root_certificates=cert) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - elif os.getenv("FEDN_GRPC_ROOT_CERT_PATH"): - secure = True - print("CLIENT: using root certificate from environment variable for GRPC channel") - with open(os.environ["FEDN_GRPC_ROOT_CERT_PATH"], 'rb') as f: - credentials = grpc.ssl_channel_credentials(f.read()) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - elif self.config['secure']: - secure = True - print("CLIENT: using CA certificate for GRPC channel") - cert = ssl.get_server_certificate((host, port)) - - credentials = grpc.ssl_channel_credentials(cert.encode('utf-8')) - if self.config['token']: - token = self.config['token'] - auth_creds = grpc.metadata_call_credentials(GrpcAuth(token)) - channel = grpc.secure_channel("{}:{}".format(host, str(port)), grpc.composite_channel_credentials(credentials, auth_creds)) - else: - channel = grpc.secure_channel("{}:{}".format(host, str(port)), credentials) - else: - print("CLIENT: using insecure GRPC channel") - if port == 443: - port = 80 - channel = grpc.insecure_channel("{}:{}".format( - host, - str(port))) - - self.channel = channel - - self.connection = rpc.ConnectorStub(channel) - self.orchestrator = rpc.CombinerStub(channel) - self.models = rpc.ModelServiceStub(channel) - - print("Client: {} connected {} to {}:{}".format(self.name, - "SECURED" if secure else "INSECURE", - host, - port), - flush=True) - - print("Client: Using {} compute package.".format( - client_config["package"])) - - def _disconnect(self): - self.channel.close() - def get_model(self, id): """Fetch a model from the assigned combiner. - Downloads the model update object via a gRPC streaming channel, Download. + Downloads the model update object via a gRPC streaming channel. :param id: The id of the model update object. :type id: str :return: The model update object. :rtype: BytesIO - """ data = BytesIO() - for part in self.models.Download(fedn.ModelRequest(id=id)): + for part in self.modelStub.Download(fedn.ModelRequest(id=id)): if part.status == fedn.ModelStatus.IN_PROGRESS: data.write(part.data) @@ -357,7 +346,6 @@ def get_model(self, id): def set_model(self, model, id): """Send a model update to the assigned combiner. - Uploads the model updated object via a gRPC streaming channel, Upload. :param model: The model update object. @@ -398,7 +386,7 @@ def upload_request_generator(mdl): if not b: break - result = self.models.Upload(upload_request_generator(bt)) + result = self.modelStub.Upload(upload_request_generator(bt)) return result @@ -416,7 +404,7 @@ def _listen_to_model_update_request_stream(self): while True: try: - for request in self.orchestrator.ModelUpdateRequestStream(r, metadata=metadata): + for request in self.combinerStub.ModelUpdateRequestStream(r, metadata=metadata): if request.sender.role == fedn.COMBINER: # Process training request self._send_status("Received model update request.", log_level=fedn.Status.AUDIT, @@ -441,7 +429,6 @@ def _listen_to_model_update_request_stream(self): def _listen_to_model_validation_request_stream(self): """Subscribe to the model validation request stream. - :return: None :rtype: None """ @@ -451,7 +438,7 @@ def _listen_to_model_validation_request_stream(self): r.sender.role = fedn.WORKER while True: try: - for request in self.orchestrator.ModelValidationRequestStream(r): + for request in self.combinerStub.ModelValidationRequestStream(r): # Process validation request _ = request.model_id self._send_status("Recieved model validation request.", log_level=fedn.Status.AUDIT, @@ -468,86 +455,6 @@ def _listen_to_model_validation_request_stream(self): if not self._attached: return - def process_request(self): - """Process training and validation tasks. """ - while True: - - if not self._attached: - return - - try: - (task_type, request) = self.inbox.get(timeout=1.0) - if task_type == 'train': - - tic = time.time() - self.state = ClientState.training - model_id, meta = self._process_training_request( - request.model_id) - processing_time = time.time()-tic - meta['processing_time'] = processing_time - meta['config'] = request.data - - if model_id is not None: - # Send model update to combiner - update = fedn.ModelUpdate() - update.sender.name = self.name - update.sender.role = fedn.WORKER - update.receiver.name = request.sender.name - update.receiver.role = request.sender.role - update.model_id = request.model_id - update.model_update_id = str(model_id) - update.timestamp = str(datetime.now()) - update.correlation_id = request.correlation_id - update.meta = json.dumps(meta) - # TODO: Check responses - _ = self.orchestrator.SendModelUpdate(update) - self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, - type=fedn.StatusType.MODEL_UPDATE, request=update) - - else: - self._send_status("Client {} failed to complete model update.", - log_level=fedn.Status.WARNING, - request=request) - self.state = ClientState.idle - self.inbox.task_done() - - elif task_type == 'validate': - self.state = ClientState.validating - metrics = self._process_validation_request( - request.model_id, request.is_inference) - - if metrics is not None: - # Send validation - validation = fedn.ModelValidation() - validation.sender.name = self.name - validation.sender.role = fedn.WORKER - validation.receiver.name = request.sender.name - validation.receiver.role = request.sender.role - validation.model_id = str(request.model_id) - validation.data = json.dumps(metrics) - self.str = str(datetime.now()) - validation.timestamp = self.str - validation.correlation_id = request.correlation_id - _ = self.orchestrator.SendModelValidation( - validation) - - # Set status type - if request.is_inference: - status_type = fedn.StatusType.INFERENCE - else: - status_type = fedn.StatusType.MODEL_VALIDATION - - self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, - type=status_type, request=validation) - else: - self._send_status("Client {} failed to complete model validation.".format(self.name), - log_level=fedn.Status.WARNING, request=request) - - self.state = ClientState.idle - self.inbox.task_done() - except queue.Empty: - pass - def _process_training_request(self, model_id): """Process a training (model update) request. @@ -555,7 +462,6 @@ def _process_training_request(self, model_id): :type model_id: str :return: The model id of the updated model, or None if the update failed. And a dict with metadata. :rtype: tuple - """ self._send_status( @@ -651,6 +557,86 @@ def _process_validation_request(self, model_id, is_inference): self.state = ClientState.idle return validation + def process_request(self): + """Process training and validation tasks. """ + while True: + + if not self._attached: + return + + try: + (task_type, request) = self.inbox.get(timeout=1.0) + if task_type == 'train': + + tic = time.time() + self.state = ClientState.training + model_id, meta = self._process_training_request( + request.model_id) + processing_time = time.time()-tic + meta['processing_time'] = processing_time + meta['config'] = request.data + + if model_id is not None: + # Send model update to combiner + update = fedn.ModelUpdate() + update.sender.name = self.name + update.sender.role = fedn.WORKER + update.receiver.name = request.sender.name + update.receiver.role = request.sender.role + update.model_id = request.model_id + update.model_update_id = str(model_id) + update.timestamp = str(datetime.now()) + update.correlation_id = request.correlation_id + update.meta = json.dumps(meta) + # TODO: Check responses + _ = self.combinerStub.SendModelUpdate(update) + self._send_status("Model update completed.", log_level=fedn.Status.AUDIT, + type=fedn.StatusType.MODEL_UPDATE, request=update) + + else: + self._send_status("Client {} failed to complete model update.", + log_level=fedn.Status.WARNING, + request=request) + self.state = ClientState.idle + self.inbox.task_done() + + elif task_type == 'validate': + self.state = ClientState.validating + metrics = self._process_validation_request( + request.model_id, request.is_inference) + + if metrics is not None: + # Send validation + validation = fedn.ModelValidation() + validation.sender.name = self.name + validation.sender.role = fedn.WORKER + validation.receiver.name = request.sender.name + validation.receiver.role = request.sender.role + validation.model_id = str(request.model_id) + validation.data = json.dumps(metrics) + self.str = str(datetime.now()) + validation.timestamp = self.str + validation.correlation_id = request.correlation_id + _ = self.combinerStub.SendModelValidation( + validation) + + # Set status type + if request.is_inference: + status_type = fedn.StatusType.INFERENCE + else: + status_type = fedn.StatusType.MODEL_VALIDATION + + self._send_status("Model validation completed.", log_level=fedn.Status.AUDIT, + type=status_type, request=validation) + else: + self._send_status("Client {} failed to complete model validation.".format(self.name), + log_level=fedn.Status.WARNING, request=request) + + self.state = ClientState.idle + self.inbox.task_done() + except queue.Empty: + pass + def _handle_combiner_failure(self): """ Register failed combiner connection.""" self._missed_heartbeat += 1 @@ -665,12 +651,11 @@ def _send_heartbeat(self, update_frequency=2.0): :return: None if the client is detached. :rtype: None """ - while True: heartbeat = fedn.Heartbeat(sender=fedn.Client( name=self.name, role=fedn.WORKER)) try: - self.connection.SendHeartbeat(heartbeat) + self.connectorStub.SendHeartbeat(heartbeat) self._missed_heartbeat = 0 except grpc.RpcError as e: status_code = e.code() @@ -709,33 +694,7 @@ def _send_status(self, msg, log_level=fedn.Status.INFO, type=None, request=None) self.logs.append( "{} {} LOG LEVEL {} MESSAGE {}".format(str(datetime.now()), status.sender.name, status.log_level, status.status)) - _ = self.connection.SendStatus(status) - - def run_web(self): - """Starts a local logging UI (Flask app) serving on port 8080. - - Currently not in use. - - """ - app = Flask(__name__) - - @ app.route('/') - def index(): - """ - - :return: - """ - logs_fancy = str() - for log in self.logs: - logs_fancy += "

" + log + "

\n" - - return page.format(client=self.name, state=ClientStateToString(self.state), style=style, logs=logs_fancy) - - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, 'w') - app.run(host="0.0.0.0", port="8080") - sys.stdout.close() - sys.stdout = self._original_stdout + _ = self.connectorStub.SendStatus(status) def run(self): """ Run the client. """ diff --git a/fedn/fedn/network/clients/connect.py b/fedn/fedn/network/clients/connect.py new file mode 100644 index 000000000..2f8acfa8d --- /dev/null +++ b/fedn/fedn/network/clients/connect.py @@ -0,0 +1,116 @@ +# This file contains the Connector class for assigning client to the FEDn network via the discovery service (REST-API). +# The Connector class is used by the Client class in fedn/network/clients/client.py. +# Once assigned, the client will retrieve combiner assignment from the discovery service. +# The discovery service will also add the client to the statestore. +# +# +import enum + +import requests + + +class Status(enum.Enum): + """ Enum for representing the status of a client assignment.""" + Unassigned = 0 + Assigned = 1 + TryAgain = 2 + UnAuthorized = 3 + UnMatchedConfig = 4 + + +class ConnectorClient: + """ Connector for assigning client to a combiner in the FEDn network. + + :param host: host of discovery service + :type host: str + :param port: port of discovery service + :type port: int + :param token: token for authentication + :type token: str + :param name: name of client + :type name: str + :param remote_package: True if remote package is used, False if local + :type remote_package: bool + :param force_ssl: True if https is used, False if http + :type force_ssl: bool + :param verify: True if certificate is verified, False if not + :type verify: bool + :param combiner: name of preferred combiner + :type combiner: str + :param id: id of client + """ + + def __init__(self, host, port, token, name, remote_package, force_ssl=False, verify=False, combiner=None, id=None): + + self.host = host + self.port = port + self.token = token + self.name = name + self.verify = verify + self.preferred_combiner = combiner + self.id = id + self.package = 'remote' if remote_package else 'local' + + # for https we assume a an ingress handles permanent redirect (308) + if force_ssl: + self.prefix = "https://" + else: + self.prefix = "http://" + if self.port: + self.connect_string = "{}{}:{}".format( + self.prefix, self.host, self.port) + else: + self.connect_string = "{}{}".format( + self.prefix, self.host) + + print("\n\nsetting the connection string to {}\n\n".format( + self.connect_string), flush=True) + + def assign(self): + """ + Connect client to FEDn network discovery service, ask for combiner assignment. + + :return: Tuple with assingment status, combiner connection information if sucessful, else None. + :rtype: tuple(:class:`fedn.network.clients.connect.Status`, str) + """ + try: + retval = None + payload = {'client_id': self.name, 'preferred_combiner': self.preferred_combiner} + + retval = requests.post(self.connect_string + '/add_client', + json=payload, + verify=self.verify, + allow_redirects=True, + headers={'Authorization': 'Token {}'.format(self.token)}) + except Exception as e: + print('***** {}'.format(e), flush=True) + return Status.Unassigned, {} + + if retval.status_code == 400: + # Get error messange from response + reason = retval.json()['message'] + return Status.UnMatchedConfig, reason + + if retval.status_code == 401: + reason = "Unauthorized connection to reducer, make sure the correct token is set" + return Status.UnAuthorized, reason + + if retval.status_code >= 200 and retval.status_code < 204: + if retval.json()['status'] == 'retry': + if 'message' in retval.json(): + reason = retval.json()['message'] + else: + reason = "Reducer was not ready. Try again later." + + return Status.TryAgain, reason + + reducer_package = retval.json()['package'] + if reducer_package != self.package: + reason = "Unmatched config of compute package between client and reducer.\n" +\ + "Reducer uses {} package and client uses {}.".format( + reducer_package, self.package) + return Status.UnMatchedConfig, reason + + return Status.Assigned, retval.json() + + return Status.Unassigned, None diff --git a/fedn/fedn/common/control/package.py b/fedn/fedn/network/clients/package.py similarity index 54% rename from fedn/fedn/common/control/package.py rename to fedn/fedn/network/clients/package.py index b7f3a3471..d6c91ccba 100644 --- a/fedn/fedn/common/control/package.py +++ b/fedn/fedn/network/clients/package.py @@ -1,5 +1,7 @@ +# This file contains the PackageRuntime class, which is used to download, validate and unpack compute packages. +# +# import cgi -import hashlib import os import tarfile from distutils.dir_util import copy_tree @@ -11,91 +13,13 @@ from fedn.utils.dispatcher import Dispatcher -class Package: - """ - - """ - - def __init__(self, config): - self.config = config - self.name = config['name'] - self.cwd = config['cwd'] - if 'port' in config: - self.reducer_port = config['port'] - if 'host' in config: - self.reducer_host = config['host'] - if 'token' in config: - self.reducer_token = config['token'] - - self.package_file = None - self.file_path = None - self.package_hash = None - - def package(self, validate=False): - """ - - :param validate: - :return: - """ - # check config - package_file = '{name}.tar.gz'.format(name=self.name) - - # package the file - cwd = os.getcwd() - self.file_path = os.getcwd() - if self.config['cwd'] == '': - self.file_path = os.getcwd() - os.chdir(self.file_path) - with tarfile.open(os.path.join(os.path.dirname(self.file_path), package_file), 'w:gz') as tf: - # for walking the current dir with absolute path (in archive) - # for root, dirs, files in os.walk(self.file_path): - # for file in files: - # tf.add(os.path.join(root, file)) - # for walking the current dir - for file in os.listdir(self.file_path): - tf.add(file) - tf.close() - - hsh = hashlib.sha256() - with open(os.path.join(os.path.dirname(self.file_path), package_file), 'rb') as f: - for byte_block in iter(lambda: f.read(4096), b""): - hsh.update(byte_block) - - os.chdir(cwd) - self.package_file = package_file - self.package_hash = hsh.hexdigest() - - return package_file, hsh.hexdigest() - - def upload(self): - """ - - """ - if self.package_file: - # data = {'name': self.package_file, 'hash': str(self.package_hash)} - # print("going to send {}".format(data),flush=True) - f = open(os.path.join(os.path.dirname( - self.file_path), self.package_file), 'rb') - print("Sending the following file {}".format(f.read()), flush=True) - f.seek(0, 0) - files = {'file': f} - try: - requests.post('https://{}:{}/context'.format(self.reducer_host, self.reducer_port), - verify=False, files=files, - # data=data, - headers={'Authorization': 'Token {}'.format(self.reducer_token)}) - except Exception as e: - print("failed to put execution context to reducer. {}".format( - e), flush=True) - finally: - f.close() - - print("Upload 4 ", flush=True) - - class PackageRuntime: - """ + """ PackageRuntime is used to download, validate and unpack compute packages. + :param package_path: path to compute package + :type package_path: str + :param package_dir: directory to unpack compute package + :type package_dir: str """ def __init__(self, package_path, package_dir): @@ -112,14 +36,14 @@ def __init__(self, package_path, package_dir): self.expected_checksum = None def download(self, host, port, token, force_ssl=False, secure=False, name=None): - """ - Download compute package from controller - - :param host: - :param port: - :param token: - :param name: - :return: + """ Download compute package from controller + + :param host: host of controller + :param port: port of controller + :param token: token for authentication + :param name: name of package + :return: True if download was successful, None otherwise + :rtype: bool """ # for https we assume a an ingress handles permanent redirect (308) if force_ssl: @@ -127,9 +51,9 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): else: scheme = "http" if port: - path = f"{scheme}://{host}:{port}/context" + path = f"{scheme}://{host}:{port}/download_package" else: - path = f"{scheme}://{host}/context" + path = f"{scheme}://{host}/download_package" if name: path = path + "?name={}".format(name) @@ -148,9 +72,9 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): for chunk in r.iter_content(chunk_size=8192): f.write(chunk) if port: - path = f"{scheme}://{host}:{port}/checksum" + path = f"{scheme}://{host}:{port}/get_package_checksum" else: - path = f"{scheme}://{host}/checksum" + path = f"{scheme}://{host}/get_package_checksum" if name: path = path + "?name={}".format(name) @@ -166,22 +90,17 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None): return True def validate(self, expected_checksum): - """ + """ Validate the package against the checksum provided by the controller - :param expected_checksum: - :return: + :param expected_checksum: checksum provided by the controller + :return: True if checksums match, False otherwise + :rtype: bool """ self.expected_checksum = expected_checksum # crosscheck checksum and unpack if security checks are ok. - # print("check if checksum {} is equal to checksum expected {}".format(self.checksum,self.expected_checksum)) file_checksum = str(sha(os.path.join(self.pkg_path, self.pkg_name))) - # catched by client, make configurable by governance network! - # if self.expected_checksum is None: - # print("CAUTION: Package validation turned OFF on client", flush=True) - # return True - if self.checksum == self.expected_checksum == file_checksum: print("Package validated {}".format(self.checksum)) return True @@ -189,8 +108,10 @@ def validate(self, expected_checksum): return False def unpack(self): - """ + """ Unpack the compute package + :return: True if unpacking was successful, False otherwise + :rtype: bool """ if self.pkg_name: f = None @@ -205,7 +126,10 @@ def unpack(self): self.pkg_path, self.pkg_name), 'r:bz2') else: print( - "Failed to unpack compute package, no pkg_name set. Has the reducer been configured with a compute package?") + "Failed to unpack compute package, no pkg_name set." + "Has the reducer been configured with a compute package?" + ) + return False os.getcwd() try: @@ -215,14 +139,18 @@ def unpack(self): f.extractall() print("Successfully extracted compute package content in {}".format( self.dir), flush=True) + return True except Exception: print("Error extracting files!") + return False def dispatcher(self, run_path): - """ + """ Dispatch the compute package - :param run_path: - :return: + :param run_path: path to dispatch the compute package + :type run_path: str + :return: Dispatcher object + :rtype: :class:`fedn.utils.dispatcher.Dispatcher` """ from_path = os.path.join(os.getcwd(), 'client') diff --git a/fedn/fedn/network/clients/state.py b/fedn/fedn/network/clients/state.py index 2afd85115..262f5862e 100644 --- a/fedn/fedn/network/clients/state.py +++ b/fedn/fedn/network/clients/state.py @@ -2,6 +2,7 @@ class ClientState(Enum): + """ Enum for representing the state of a client.""" idle = 1 training = 2 validating = 3 diff --git a/fedn/fedn/network/combiner/__init__.py b/fedn/fedn/network/combiner/__init__.py index e69de29bb..462f91393 100644 --- a/fedn/fedn/network/combiner/__init__.py +++ b/fedn/fedn/network/combiner/__init__.py @@ -0,0 +1 @@ +""" The FEDn Combiner package is responsible for combining models from multiple clients. It's the acting gRPC server for the federated network.""" diff --git a/fedn/fedn/network/combiner/aggregators/__init__.py b/fedn/fedn/network/combiner/aggregators/__init__.py index e69de29bb..cb7ee83d6 100644 --- a/fedn/fedn/network/combiner/aggregators/__init__.py +++ b/fedn/fedn/network/combiner/aggregators/__init__.py @@ -0,0 +1,3 @@ +""" The aggregator package is responsible for aggregating models from multiple clients. It's called both in :class:`fedn.network.combiner.Combiner` and :class:`fedn.network.controller.Controller` +to aggregate models from clients. """ +# flake8: noqa diff --git a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/fedn/network/combiner/aggregators/aggregatorbase.py index e075f7142..bcbb699e2 100644 --- a/fedn/fedn/network/combiner/aggregators/aggregatorbase.py +++ b/fedn/fedn/network/combiner/aggregators/aggregatorbase.py @@ -9,23 +9,23 @@ class AggregatorBase(ABC): - """ Abstract class defining an aggregator. """ + """ Abstract class defining an aggregator. + + :param id: A reference to id of :class: `fedn.network.combiner.Combiner` + :type id: str + :param storage: Model repository for :class: `fedn.network.combiner.Combiner` + :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository` + :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner` + :type server: class: `fedn.network.combiner.Combiner` + :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` + :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` + :param control: A handle to the :class: `fedn.network.combiner.round.RoundController` + :type control: class: `fedn.network.combiner.round.RoundController` + """ @abstractmethod def __init__(self, storage, server, modelservice, control): - """ Initialize the aggregator. - - :param id: A reference to id of :class: `fedn.network.combiner.Combiner` - :type id: str - :param storage: Model repository for :class: `fedn.network.combiner.Combiner` - :type storage: class: `fedn.common.storage.s3.s3repo.S3ModelRepository` - :param server: A handle to the Combiner class :class: `fedn.network.combiner.Combiner` - :type server: class: `fedn.network.combiner.Combiner` - :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` - :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` - :param control: A handle to the :class: `fedn.network.combiner.round.RoundController` - :type control: class: `fedn.network.combiner.round.RoundController` - """ + """ Initialize the aggregator.""" self.name = self.__class__.__name__ self.storage = storage self.server = server diff --git a/fedn/fedn/network/combiner/connect.py b/fedn/fedn/network/combiner/connect.py new file mode 100644 index 000000000..de705a56c --- /dev/null +++ b/fedn/fedn/network/combiner/connect.py @@ -0,0 +1,125 @@ +# This file contains the Connector class for announcing combiner to the FEDn network via the discovery service (REST-API). +# The Connector class is used by the Combiner class in fedn/network/combiner/server.py. +# Once announced, the combiner will be able to receive controller requests from the controllerStub via gRPC. +# The discovery service will also add the combiner to the statestore. +# +# +import enum + +import requests + + +class Status(enum.Enum): + """ Enum for representing the status of a combiner announcement.""" + Unassigned = 0 + Assigned = 1 + TryAgain = 2 + UnAuthorized = 3 + UnMatchedConfig = 4 + + +class ConnectorCombiner: + """ Connector for annnouncing combiner to the FEDn network. + + :param host: host of discovery service + :type host: str + :param port: port of discovery service + :type port: int + :param myhost: host of combiner + :type myhost: str + :param fqdn: fully qualified domain name of combiner + :type fqdn: str + :param myport: port of combiner + :type myport: int + :param token: token for authentication + :type token: str + :param name: name of combiner + :type name: str + :param secure: True if https is used, False if http + :type secure: bool + :param verify: True if certificate is verified, False if not + :type verify: bool + """ + + def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False, verify=False): + """ Initialize the ConnectorCombiner. + + :param host: The host of the discovery service. + :type host: str + :param port: The port of the discovery service. + :type port: int + :param myhost: The host of the combiner. + :type myhost: str + :param fqdn: The fully qualified domain name of the combiner. + :type fqdn: str + :param myport: The port of the combiner. + :type myport: int + :param token: The token for the discovery service. + :type token: str + :param name: The name of the combiner. + :type name: str + :param secure: Use https for the connection to the discovery service. + :type secure: bool + :param verify: Verify the connection to the discovery service. + :type verify: bool + """ + + self.host = host + self.fqdn = fqdn + self.port = port + self.myhost = myhost + self.myport = myport + self.token = token + self.name = name + self.secure = secure + self.verify = verify + + # for https we assume a an ingress handles permanent redirect (308) + self.prefix = "http://" + if port: + self.connect_string = "{}{}:{}".format( + self.prefix, self.host, self.port) + else: + self.connect_string = "{}{}".format( + self.prefix, self.host) + + print("\n\nsetting the connection string to {}\n\n".format( + self.connect_string), flush=True) + + def announce(self): + """ + Announce combiner to FEDn network via discovery service (REST-API). + + :return: Tuple with announcement Status, FEDn network configuration if sucessful, else None. + :rtype: :class:`fedn.network.combiner.connect.Status`, str + """ + payload = { + "combiner_id": self.name, + "address": self.myhost, + "fqdn": self.fqdn, + "port": self.myport, + "secure_grpc": self.secure + } + try: + retval = requests.post(self.connect_string + '/add_combiner', json=payload, + verify=self.verify, + headers={'Authorization': 'Token {}'.format(self.token)}) + except Exception: + return Status.Unassigned, {} + + if retval.status_code == 400: + # Get error messange from response + reason = retval.json()['message'] + return Status.UnMatchedConfig, reason + + if retval.status_code == 401: + reason = "Unauthorized connection to reducer, make sure the correct token is set" + return Status.UnAuthorized, reason + + if retval.status_code >= 200 and retval.status_code < 204: + if retval.json()['status'] == 'retry': + reason = retval.json()['message'] + return Status.TryAgain, reason + return Status.Assigned, retval.json() + + return Status.Unassigned, None diff --git a/fedn/fedn/network/combiner/interfaces.py b/fedn/fedn/network/combiner/interfaces.py index 832c5fdfa..6dfb0428d 100644 --- a/fedn/fedn/network/combiner/interfaces.py +++ b/fedn/fedn/network/combiner/interfaces.py @@ -4,6 +4,7 @@ from io import BytesIO import grpc +from google.protobuf.json_format import MessageToJson import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc @@ -14,7 +15,15 @@ class CombinerUnavailableError(Exception): class Channel: - """ Wrapper for a gRPC channel. """ + """ Wrapper for a gRPC channel. + + :param address: The address for the gRPC server. + :type address: str + :param port: The port for connecting to the gRPC server. + :type port: int + :param certificate: The certificate for connecting to the gRPC server (optional) + :type certificate: str + """ def __init__(self, address, port, certificate=None): """ Create a channel. @@ -51,33 +60,31 @@ def get_channel(self): class CombinerInterface: - """ Interface for the Combiner (server). + """ Interface for the Combiner (aggregation server). Abstraction on top of the gRPC server servicer. + :param parent: The parent combiner (controller) + :type parent: :class:`fedn.network.api.interfaces.API` + :param name: The name of the combiner. + :type name: str + :param address: The address of the combiner. + :type address: str + :param fqdn: The fully qualified domain name of the combiner. + :type fqdn: str + :param port: The port of the combiner. + :type port: int + :param certificate: The certificate of the combiner (optional). + :type certificate: str + :param key: The key of the combiner (optional). + :type key: str + :param ip: The ip of the combiner (optional). + :type ip: str + :param config: The configuration of the combiner (optional). + :type config: dict """ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None): - """ Initialize the combiner interface. - - :parameter parent: The parent combiner. - :type parent: :class:`fedn.network.combiner.Combiner` - :parameter name: The name of the combiner. - :type name: str - :parameter address: The address of the combiner. - :type address: str - :parameter fqdn: The fully qualified domain name of the combiner. - :type fqdn: str - :parameter port: The port of the combiner. - :type port: int - :parameter certificate: The certificate of the combiner (optional). - :type certificate: str - :parameter key: The key of the combiner (optional). - :type key: str - :parameter ip: The ip of the combiner (optional). - :type ip: str - :parameter config: The configuration of the combiner (optional). - :type config: dict - """ + """ Initialize the combiner interface.""" self.parent = parent self.name = name self.address = address @@ -108,12 +115,12 @@ def from_json(combiner_config): def to_dict(self): """ Export combiner configuration to a dictionary. - : return: A dictionary with the combiner configuration. - : rtype: dict + :return: A dictionary with the combiner configuration. + :rtype: dict """ data = { - 'parent': self.parent.to_dict(), + 'parent': self.parent, 'name': self.name, 'address': self.address, 'fqdn': self.fqdn, @@ -174,7 +181,6 @@ def report(self): :return: A dictionary describing the combiner state. :rtype: dict - :raises CombinerUnavailableError: If the combiner is unavailable. """ channel = Channel(self.address, self.port, @@ -194,8 +200,7 @@ def report(self): raise def configure(self, config=None): - """ Configure the combiner. - Set the parameters in config at the server. + """ Configure the combiner. Set the parameters in config at the server. :param config: A dictionary containing parameters. :type config: dict @@ -220,6 +225,23 @@ def configure(self, config=None): else: raise + def flush_model_update_queue(self): + """ Reset the model update queue on the combiner. """ + + channel = Channel(self.address, self.port, + self.certificate).get_channel() + control = rpc.ControlStub(channel) + + request = fedn.ControlRequest() + + try: + control.FlushAggregationQueue(request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + raise CombinerUnavailableError + else: + raise + def submit(self, config): """ Submit a compute plan to the combiner. @@ -251,7 +273,6 @@ def submit(self, config): def get_model(self, id): """ Download a model from the combiner server. - :param id: The model id. :type id: str :return: A file-like object containing the model. @@ -300,3 +321,22 @@ def allowing_clients(self): return False return False + + def list_active_clients(self): + """ List active clients. + + :return: A list of active clients. + :rtype: json + """ + channel = Channel(self.address, self.port, + self.certificate).get_channel() + control = rpc.ConnectorStub(channel) + request = fedn.ListClientsRequest() + try: + response = control.ListActiveClients(request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + raise CombinerUnavailableError + else: + raise + return MessageToJson(response) diff --git a/fedn/fedn/network/combiner/round.py b/fedn/fedn/network/combiner/round.py index 6ae7a92cc..dd41deee3 100644 --- a/fedn/fedn/network/combiner/round.py +++ b/fedn/fedn/network/combiner/round.py @@ -29,6 +29,7 @@ class RoundController: """ def __init__(self, aggregator_name, storage, server, modelservice): + """ Initialize the RoundController.""" self.round_configs = queue.Queue() self.storage = storage @@ -118,7 +119,6 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1): """ time_window = float(config['round_timeout']) - # buffer_size = int(config['buffer_size']) tt = 0.0 while tt < time_window: @@ -150,8 +150,14 @@ def _training_round(self, config, clients): # Request model updates from all active clients. self.server.request_model_update(config, clients=clients) + # If buffer_size is -1 (default), the round terminates when/if all clients have completed. + if int(config['buffer_size']) == -1: + buffer_size = len(clients) + else: + buffer_size = int(config['buffer_size']) + # Wait / block until the round termination policy has been met. - self.waitforit(config, buffer_size=len(clients)) + self.waitforit(config, buffer_size=buffer_size) tic = time.time() model = None @@ -159,7 +165,6 @@ def _training_round(self, config, clients): try: helper = get_helper(config['helper_type']) - # print config delete_models_storage print("ROUNDCONTROL: Config delete_models_storage: {}".format(config['delete_models_storage']), flush=True) if config['delete_models_storage'] == 'True': delete_models = True @@ -199,9 +204,9 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): # If the model is already in memory at the server we do not need to do anything. if self.modelservice.models.exist(model_id): - print("MODEL EXISTST (NOT)", flush=True) + print("ROUNDCONTROL: Model already exists in memory, skipping model staging.", flush=True) return - print("MODEL STAGING", flush=True) + print("ROUNDCONTROL: Model Staging, fetching model from storage...", flush=True) # If not, download it and stage it in memory at the combiner. tries = 0 while True: diff --git a/fedn/fedn/network/combiner/server.py b/fedn/fedn/network/combiner/server.py index ff0134d6b..7a9c87ff9 100644 --- a/fedn/fedn/network/combiner/server.py +++ b/fedn/fedn/network/combiner/server.py @@ -12,10 +12,10 @@ import fedn.common.net.grpc.fedn_pb2 as fedn import fedn.common.net.grpc.fedn_pb2_grpc as rpc -from fedn.common.net.connect import ConnectorCombiner, Status from fedn.common.net.grpc.server import Server from fedn.common.storage.s3.s3repo import S3ModelRepository from fedn.common.tracer.mongotracer import MongoTracer +from fedn.network.combiner.connect import ConnectorCombiner, Status from fedn.network.combiner.modelservice import ModelService from fedn.network.combiner.round import RoundController @@ -49,14 +49,14 @@ def role_to_proto_role(role): class Combiner(rpc.CombinerServicer, rpc.ReducerServicer, rpc.ConnectorServicer, rpc.ControlServicer): - """ Combiner gRPC server. """ + """ Combiner gRPC server. - def __init__(self, config): - """ Initialize a Combiner. + :param config: configuration for the combiner + :type config: dict + """ - :param config: configuration for the combiner - :type config: dict - """ + def __init__(self, config): + """ Initialize Combiner server.""" # Client queues self.clients = {} @@ -98,7 +98,12 @@ def __init__(self, config): break if status == Status.UnAuthorized: print(response, flush=True) + print("Status.UnAuthorized", flush=True) sys.exit("Exiting: Unauthorized") + if status == Status.UnMatchedConfig: + print(response, flush=True) + print("Status.UnMatchedConfig", flush=True) + sys.exit("Exiting: Missing config") cert = announce_config['certificate'] key = announce_config['key'] @@ -371,6 +376,19 @@ def __register_heartbeat(self, client): self.__join_client(client) self.clients[client.name]["lastseen"] = datetime.now() + def flush_model_update_queue(self): + """Clear the model update queue (aggregator). """ + + q = self.control.aggregator.model_updates + try: + with q.mutex: + q.queue.clear() + q.all_tasks_done.notify_all() + q.unfinished_tasks = 0 + return True + except Exception: + return False + ##################################################################################################################### # Control Service @@ -400,6 +418,9 @@ def Start(self, control: fedn.ControlRequest, context): return response + # RPCs related to remote configuration of the server, round controller, + # aggregator and their states. + def Configure(self, control: fedn.ControlRequest, context): """ Configure the Combiner. @@ -416,6 +437,29 @@ def Configure(self, control: fedn.ControlRequest, context): response = fedn.ControlResponse() return response + def FlushAggregationQueue(self, control: fedn.ControlRequest, context): + """ Flush the queue. + + :param control: the control request + :type control: :class:`fedn.common.net.grpc.fedn_pb2.ControlRequest` + :param context: the context (unused) + :type context: :class:`grpc._server._Context` + :return: the control response + :rtype: :class:`fedn.common.net.grpc.fedn_pb2.ControlResponse` + """ + + status = self.flush_model_update_queue() + + response = fedn.ControlResponse() + if status: + response.message = 'Success' + else: + response.message = 'Failed' + + return response + + ############################################################################## + def Stop(self, control: fedn.ControlRequest, context): """ TODO: Not yet implemented. @@ -494,25 +538,6 @@ def Report(self, control: fedn.ControlRequest, context): ##################################################################################################################### - def AllianceStatusStream(self, response, context): - """ A server stream RPC endpoint that emits status messages. - - :param response: the response - :type response: :class:`fedn.common.net.grpc.fedn_pb2.Response` - :param context: the context (unused) - :type context: :class:`grpc._server._Context`""" - status = fedn.Status( - status="Client {} connecting to AllianceStatusStream.".format(response.sender)) - status.log_level = fedn.Status.INFO - status.sender.name = self.id - status.sender.role = role_to_proto_role(self.role) - self._subscribe_client_to_queue(response.sender, fedn.Channel.STATUS) - q = self.__get_queue(response.sender, fedn.Channel.STATUS) - self._send_status(status) - - while True: - yield q.get() - def SendStatus(self, status: fedn.Status, context): """ A client stream RPC endpoint that accepts status messages. @@ -689,12 +714,16 @@ def ModelUpdateRequestStream(self, response, context): self._send_status(status) + self.tracer.update_client_status(client.name, "online") + while context.is_active(): try: yield q.get(timeout=1.0) except queue.Empty: pass + self.tracer.update_client_status(client.name, "offline") + def ModelValidationStream(self, update, context): """ Model validation stream RPC endpoint. Update status for client is connecting to stream. diff --git a/fedn/fedn/network/config.py b/fedn/fedn/network/config.py index 0ccfddc66..a9e8773f4 100644 --- a/fedn/fedn/network/config.py +++ b/fedn/fedn/network/config.py @@ -6,9 +6,7 @@ class Config(ABC): class ReducerConfig(Config): - """ - - """ + """ Configuration for the Reducer component. """ compute_bundle_dir = None models_dir = None diff --git a/fedn/fedn/network/controller/__init__.py b/fedn/fedn/network/controller/__init__.py index e69de29bb..76372d276 100644 --- a/fedn/fedn/network/controller/__init__.py +++ b/fedn/fedn/network/controller/__init__.py @@ -0,0 +1,3 @@ +""" The controller package is responsible for orchestrating the federated learning process. It's acts as a gRPC client and sends round config tasks +to the :class:`fedn.network.combiner.Combiner`. """ +# flake8: noqa diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index 7929f2892..a8e32333d 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -1,4 +1,5 @@ import copy +import datetime import time import uuid @@ -8,10 +9,10 @@ class UnsupportedStorageBackend(Exception): - """ Exception class for when storage backend is not supported. Passes """ + """Exception class for when storage backend is not supported. Passes""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -22,49 +23,46 @@ def __init__(self, message): class MisconfiguredStorageBackend(Exception): - """ Exception class for when storage backend is misconfigured. """ + """Exception class for when storage backend is misconfigured. - def __init__(self, message): - """ Constructor method. - - :param message: The exception message. - :type message: str + :param message: The exception message. + :type message: str + """ - """ + def __init__(self, message): + """Constructor method.""" self.message = message super().__init__(self.message) class NoModelException(Exception): - """ Exception class for when model is None """ + """Exception class for when model is None - def __init__(self, message): - """ Constructor method. + :param message: The exception message. + :type message: str + """ - :param message: The exception message. - :type message: str - - """ + def __init__(self, message): + """Constructor method.""" self.message = message super().__init__(self.message) class Control(ControlBase): - """ Controller, implementing the overall global training, validation and inference logic. """ + """Controller, implementing the overall global training, validation and inference logic. - def __init__(self, statestore): - """ Constructor method. + :param statestore: A StateStorage instance. + :type statestore: class: `fedn.network.statestorebase.StateStorageBase` + """ - :param statestore: A StateStorage instance. - :type statestore: class: `fedn.common.storage.statestorage.StateStorage` - - """ + def __init__(self, statestore): + """Constructor method.""" super().__init__(statestore) self.name = "DefaultControl" def session(self, config): - """ Execute a new training session. A session consists of one + """Execute a new training session. A session consists of one or several global rounds. All rounds in the same session have the same round_config. @@ -74,26 +72,41 @@ def session(self, config): """ if self._state == ReducerState.instructing: - print("Controller already in INSTRUCTING state. A session is in progress.", flush=True) + print( + "Controller already in INSTRUCTING state. A session is in progress.", + flush=True, + ) + return + + if not self.statestore.get_latest_model(): + print("No model in model chain, please provide a seed model!") return self._state = ReducerState.instructing # Must be called to set info in the db + config["committed_at"] = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) self.new_session(config) - if not self.get_latest_model(): - print("No model in model chain, please provide a seed model!") - + if not self.statestore.get_latest_model(): + print( + "No model in model chain, please provide a seed model!", + flush=True, + ) self._state = ReducerState.monitoring last_round = int(self.get_latest_round_id()) + # Clear potential stragglers/old model updates at combiners + for combiner in self.network.get_combiners(): + combiner.flush_model_update_queue() + # Execute the rounds in this session - for round in range(1, int(config['rounds'] + 1)): + for round in range(1, int(config["rounds"] + 1)): # Increment the round number - # round_id = self.new_round(session['session_id']) if last_round: current_round = last_round + round else: @@ -102,10 +115,17 @@ def session(self, config): try: _, round_data = self.round(config, current_round) except TypeError as e: - print("Could not unpack data from round: {0}".format(e), flush=True) - - print("CONTROL: Round completed with status {}".format( - round_data['status']), flush=True) + print( + "Could not unpack data from round: {0}".format(e), + flush=True, + ) + + print( + "CONTROL: Round completed with status {}".format( + round_data["status"] + ), + flush=True, + ) self.tracer.set_round_data(round_data) @@ -113,44 +133,50 @@ def session(self, config): self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute a single global round. + """Execute a single global round. :param session_config: The session config. :type session_config: dict :param round_id: The round id. :type round_id: str(int) - """ - round_data = {'round_id': round_id} + round_data = {"round_id": round_id} if len(self.network.get_combiners()) < 1: print("REDUCER: No combiners connected!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data # 1. Assemble round config for this global round, # and check which combiners are able to participate # in the round. round_config = copy.deepcopy(session_config) - round_config['rounds'] = 1 - round_config['round_id'] = round_id - round_config['task'] = 'training' - round_config['model_id'] = self.get_latest_model() - round_config['helper_type'] = self.statestore.get_helper() + round_config["rounds"] = 1 + round_config["round_id"] = round_id + round_config["task"] = "training" + round_config["model_id"] = self.statestore.get_latest_model() + round_config["helper_type"] = self.statestore.get_helper() combiners = self.get_participating_combiners(round_config) round_start = self.evaluate_round_start_policy(combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) + round_data["status"] = "Failed" return None - round_data['round_config'] = round_config + round_data["round_config"] = round_config # 2. Ask participating combiners to coordinate model updates _ = self.request_model_updates(combiners) @@ -160,27 +186,37 @@ def round(self, session_config, round_id): # dict to store combiners that have successfully produced an updated model updated = {} # wait until all combiners have produced an updated model or until round timeout - print("CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id), flush=True) + print( + "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( + round_id=round_id + ), + flush=True, + ) while len(updated) < len(combiners): round = self.statestore.get_round(round_id) if round: print("CONTROL: Round found!", flush=True) # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round['combiners']: + for combiner in round["combiners"]: print(combiner, flush=True) - if combiner['status'] == 'Success': - if combiner['name'] not in updated.keys(): + if combiner["status"] == "Success": + if combiner["name"] not in updated.keys(): # Add combiner to updated dict - updated[combiner['name']] = combiner['model_id'] + updated[combiner["name"]] = combiner["model_id"] # Print combiner status - print("CONTROL: Combiner {name} status: {status}".format( - name=combiner['name'], status=combiner['status']), flush=True) + print( + "CONTROL: Combiner {name} status: {status}".format( + name=combiner["name"], status=combiner["status"] + ), + flush=True, + ) else: # Print every 10 seconds based on value of wait if wait % 10 == 0: - print("CONTROL: Round not found! Waiting...", flush=True) - if wait >= session_config['round_timeout']: + print( + "CONTROL: Waiting for round to complete...", flush=True + ) + if wait >= session_config["round_timeout"]: print("CONTROL: Round timeout! Exiting round...", flush=True) break # Update wait time used for timeout @@ -190,53 +226,77 @@ def round(self, session_config, round_id): round_valid = self.evaluate_round_validity_policy(updated) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data print("CONTROL: Reducing models from combiners...", flush=True) # 3. Reduce combiner models into a global model try: model, data = self.reduce(updated) - round_data['reduce'] = data + round_data["reduce"] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print("CONTROL: Failed to reduce models from combiners: {}".format( - e), flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Failed to reduce models from combiners: {}".format( + e + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data # 6. Commit the global model to model trail if model is not None: - print("CONTROL: Committing global model to model trail...", flush=True) + print( + "CONTROL: Committing global model to model trail...", + flush=True, + ) tic = time.time() model_id = uuid.uuid4() - self.commit(model_id, model) - round_data['time_commit'] = time.time() - tic - print("CONTROL: Done committing global model to model trail!", flush=True) + session_id = ( + session_config["session_id"] + if "session_id" in session_config + else None + ) + self.commit(model_id, model, session_id) + round_data["time_commit"] = time.time() - tic + print( + "CONTROL: Done committing global model to model trail!", + flush=True, + ) else: - print("REDUCER: failed to update model in round with config {}".format( - session_config), flush=True) - round_data['status'] = 'Failed' + print( + "REDUCER: failed to update model in round with config {}".format( + session_config + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data - round_data['status'] = 'Success' + round_data["status"] = "Success" # 4. Trigger participating combiner nodes to execute a validation round for the current model - validate = session_config['validate'] + validate = session_config["validate"] if validate: combiner_config = copy.deepcopy(session_config) - combiner_config['round_id'] = round_id - combiner_config['model_id'] = self.get_latest_model() - combiner_config['task'] = 'validation' - combiner_config['helper_type'] = self.statestore.get_helper() + combiner_config["round_id"] = round_id + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "validation" + combiner_config["helper_type"] = self.statestore.get_helper() validating_combiners = self._select_participating_combiners( - combiner_config) + combiner_config + ) for combiner, combiner_config in validating_combiners: try: - print("CONTROL: Submitting validation round to combiner {}".format( - combiner), flush=True) + print( + "CONTROL: Submitting validation round to combiner {}".format( + combiner + ), + flush=True, + ) combiner.submit(combiner_config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) @@ -245,16 +305,16 @@ def round(self, session_config, round_id): return model_id, round_data def reduce(self, combiners): - """ Combine updated models from Combiner nodes into one global model. + """Combine updated models from Combiner nodes into one global model. :param combiners: dict of combiner names (key) and model IDs (value) to reduce :type combiners: dict """ meta = {} - meta['time_fetch_model'] = 0.0 - meta['time_load_model'] = 0.0 - meta['time_aggregate_model'] = 0.0 + meta["time_fetch_model"] = 0.0 + meta["time_load_model"] = 0.0 + meta["time_aggregate_model"] = 0.0 i = 1 model = None @@ -264,18 +324,25 @@ def reduce(self, combiners): return model, meta for name, model_id in combiners.items(): - # TODO: Handle inactive RPC error in get_model and raise specific error - print("REDUCER: Fetching model ({model_id}) from combiner {name}".format( - model_id=model_id, name=name), flush=True) + print( + "REDUCER: Fetching model ({model_id}) from combiner {name}".format( + model_id=model_id, name=name + ), + flush=True, + ) try: tic = time.time() combiner = self.get_combiner(name) data = combiner.get_model(model_id) - meta['time_fetch_model'] += (time.time() - tic) + meta["time_fetch_model"] += time.time() - tic except Exception as e: - print("REDUCER: Failed to fetch model from combiner {}: {}".format( - name, e), flush=True) + print( + "REDUCER: Failed to fetch model from combiner {}: {}".format( + name, e + ), + flush=True, + ) data = None if data is not None: @@ -284,21 +351,21 @@ def reduce(self, combiners): helper = self.get_helper() data.seek(0) model_next = helper.load(data) - meta['time_load_model'] += (time.time() - tic) + meta["time_load_model"] += time.time() - tic tic = time.time() model = helper.increment_average(model, model_next, i, i) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic except Exception: tic = time.time() data.seek(0) model = helper.load(data) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic i = i + 1 return model, meta def infer_instruct(self, config): - """ Main entrypoint for executing the inference compute plan. + """Main entrypoint for executing the inference compute plan. :param config: configuration for the inference round """ @@ -310,7 +377,7 @@ def infer_instruct(self, config): self.__state = ReducerState.instructing # Check for a model chain - if not self.get_latest_model(): + if not self.statestore.latest_model(): print("No model in model chain, please seed the alliance!") # Set reducer in monitoring state @@ -326,7 +393,7 @@ def infer_instruct(self, config): self.__state = ReducerState.idle def inference_round(self, config): - """ Execute an inference round. + """Execute an inference round. :param config: configuration for the inference round """ @@ -341,21 +408,27 @@ def inference_round(self, config): # Setup combiner configuration combiner_config = copy.deepcopy(config) - combiner_config['model_id'] = self.get_latest_model() - combiner_config['task'] = 'inference' - combiner_config['helper_type'] = self.statestore.get_framework() + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "inference" + combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners( - combiner_config) + validating_combiners = self._select_round_combiners(combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - validating_combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + validating_combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) return None # Synch combiners with latest model and trigger inference diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index 471e47a6a..077620c14 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -1,14 +1,18 @@ import os import uuid from abc import ABC, abstractmethod +from time import sleep import fedn.utils.helpers from fedn.common.storage.s3.s3repo import S3ModelRepository from fedn.common.tracer.mongotracer import MongoTracer +from fedn.network.api.network import Network from fedn.network.combiner.interfaces import CombinerUnavailableError -from fedn.network.network import Network from fedn.network.state import ReducerState +# Maximum number of tries to connect to statestore and retrieve storage configuration +MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10) + class UnsupportedStorageBackend(Exception): pass @@ -23,13 +27,16 @@ class MisconfiguredHelper(Exception): class ControlBase(ABC): - """ Base class and interface for a global controller. + """Base class and interface for a global controller. Override this class to implement a global training strategy (control). + + :param statestore: The statestore object. + :type statestore: :class:`fedn.network.statestore.statestorebase.StateStoreBase` """ @abstractmethod def __init__(self, statestore): - """ """ + """Constructor.""" self._state = ReducerState.setup self.statestore = statestore @@ -37,22 +44,44 @@ def __init__(self, statestore): self.network = Network(self, statestore) try: - storage_config = self.statestore.get_storage_backend() + not_ready = True + tries = 0 + while not_ready: + storage_config = self.statestore.get_storage_backend() + if storage_config: + not_ready = False + else: + print( + "REDUCER CONTROL: Storage backend not configured, waiting...", + flush=True, + ) + sleep(5) + tries += 1 + if tries > MAX_TRIES_BACKEND: + raise Exception except Exception: print( - "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) + "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", + flush=True, + ) raise MisconfiguredStorageBackend() - if storage_config['storage_type'] == 'S3': - self.model_repository = S3ModelRepository(storage_config['storage_config']) + if storage_config["storage_type"] == "S3": + self.model_repository = S3ModelRepository( + storage_config["storage_config"] + ) else: - print("REDUCER CONTROL: Unsupported storage backend, exiting.", flush=True) + print( + "REDUCER CONTROL: Unsupported storage backend, exiting.", + flush=True, + ) raise UnsupportedStorageBackend() # The tracer is a helper that manages state in the database backend statestore_config = statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], statestore_config["network_id"] + ) if self.statestore.is_inited(): self._state = ReducerState.idle @@ -70,54 +99,48 @@ def reduce(self, combiners): pass def get_helper(self): - """ Get a helper instance from global config. + """Get a helper instance from global config. :return: Helper instance. + :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` """ helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) + raise MisconfiguredHelper( + "Unsupported helper type {}, please configure compute_package.helper !".format( + helper_type + ) + ) return helper def get_state(self): - """ + """Get the current state of the controller. - :return: + :return: The current state. + :rtype: :class:`fedn.network.state.ReducerState` """ return self._state def idle(self): - """ + """Check if the controller is idle. - :return: + :return: True if idle, False otherwise. + :rtype: bool """ if self._state == ReducerState.idle: return True else: return False - def get_first_model(self): - """ - - :return: - """ - return self.statestore.get_first() - - def get_latest_model(self): - """ - - :return: - """ - return self.statestore.get_latest() - def get_model_info(self): """ :return: """ - return self.statestore.get_model_info() + return self.statestore.get_model_trail() + # TODO: remove use statestore.get_events() instead def get_events(self): """ @@ -130,7 +153,7 @@ def get_latest_round_id(self): if not last_round: return 0 else: - return last_round['round_id'] + return last_round["round_id"] def get_latest_round(self): round = self.statestore.get_latest_round() @@ -144,27 +167,29 @@ def get_compute_package_name(self): definition = self.statestore.get_compute_package() if definition: try: - package_name = definition['filename'] + package_name = definition["filename"] return package_name except (IndexError, KeyError): print( - "No context filename set for compute context definition", flush=True) + "No context filename set for compute context definition", + flush=True, + ) return None else: return None def set_compute_package(self, filename, path): - """ Persist the configuration for the compute package. """ + """Persist the configuration for the compute package.""" self.model_repository.set_compute_package(filename, path) self.statestore.set_compute_package(filename) - def get_compute_package(self, compute_package=''): + def get_compute_package(self, compute_package=""): """ :param compute_package: :return: """ - if compute_package == '': + if compute_package == "": compute_package = self.get_compute_package_name() if compute_package: return self.model_repository.get_compute_package(compute_package) @@ -172,40 +197,49 @@ def get_compute_package(self, compute_package=''): return None def new_session(self, config): - """ Initialize a new session in backend db. """ + """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() - config['session_id'] = str(session_id) + config["session_id"] = str(session_id) + else: + session_id = config["session_id"] self.tracer.new_session(id=session_id) self.tracer.set_session_config(session_id, config) def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update. """ + """Call Combiner server RPC to get a model update.""" cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) cl.append((combiner, response)) return cl - def commit(self, model_id, model=None): - """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """ + def commit(self, model_id, model=None, session_id=None): + """Commit a model to the global model trail. The model commited becomes the lastest consensus model.""" helper = self.get_helper() if model is not None: - print("CONTROL: Saving model file temporarily to disk...", flush=True) + print( + "CONTROL: Saving model file temporarily to disk...", flush=True + ) outfile_name = helper.save(model) print("CONTROL: Uploading model to Minio...", flush=True) model_id = self.model_repository.set_model( - outfile_name, is_file=True) + outfile_name, is_file=True + ) print("CONTROL: Deleting temporary model file...", flush=True) os.unlink(outfile_name) - print("CONTROL: Committing model {} to global model trail in statestore...".format( - model_id), flush=True) - self.statestore.set_latest(model_id) + print( + "CONTROL: Committing model {} to global model trail in statestore...".format( + model_id + ), + flush=True, + ) + self.statestore.set_latest_model(model_id, session_id) def get_combiner(self, name): for combiner in self.network.get_combiners(): @@ -215,7 +249,7 @@ def get_combiner(self, name): def get_participating_combiners(self, combiner_round_config): """Assemble a list of combiners able to participate in a round as - descibed by combiner_round_config. + descibed by combiner_round_config. """ combiners = [] for combiner in self.network.get_combiners(): @@ -227,45 +261,47 @@ def get_participating_combiners(self, combiner_round_config): if combiner_state is not None: is_participating = self.evaluate_round_participation_policy( - combiner_round_config, combiner_state) + combiner_round_config, combiner_state + ) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy(self, compute_plan, combiner_state): - """ Evaluate policy for combiner round-participation. - A combiner participates if it is responsive and reports enough - active clients to participate in the round. + def evaluate_round_participation_policy( + self, compute_plan, combiner_state + ): + """Evaluate policy for combiner round-participation. + A combiner participates if it is responsive and reports enough + active clients to participate in the round. """ - if compute_plan['task'] == 'training': - nr_active_clients = int(combiner_state['nr_active_trainers']) - elif compute_plan['task'] == 'validation': - nr_active_clients = int(combiner_state['nr_active_validators']) + if compute_plan["task"] == "training": + nr_active_clients = int(combiner_state["nr_active_trainers"]) + elif compute_plan["task"] == "validation": + nr_active_clients = int(combiner_state["nr_active_validators"]) else: print("Invalid task type!", flush=True) return False - if int(compute_plan['clients_required']) <= nr_active_clients: + if int(compute_plan["clients_required"]) <= nr_active_clients: return True else: return False def evaluate_round_start_policy(self, combiners): - """ Check if the policy to start a round is met. """ + """Check if the policy to start a round is met.""" if len(combiners) > 0: - return True else: return False def evaluate_round_validity_policy(self, combiners): - """ Check if the round should be seen as valid. + """Check if the round should be seen as valid. - At the end of the round, before committing a model to the global model trail, - we check if the round validity policy has been met. This can involve - e.g. asserting that a certain number of combiners have reported in an - updated model, or that criteria on model performance have been met. + At the end of the round, before committing a model to the global model trail, + we check if the round validity policy has been met. This can involve + e.g. asserting that a certain number of combiners have reported in an + updated model, or that criteria on model performance have been met. """ if combiners.keys() == []: return False @@ -283,7 +319,8 @@ def _select_participating_combiners(self, compute_plan): if combiner_state: is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state) + compute_plan, combiner_state + ) if is_participating: participating_combiners.append((combiner, compute_plan)) return participating_combiners diff --git a/fedn/fedn/network/dashboard/restservice.py b/fedn/fedn/network/dashboard/restservice.py index b951fb83d..14e7266bb 100644 --- a/fedn/fedn/network/dashboard/restservice.py +++ b/fedn/fedn/network/dashboard/restservice.py @@ -23,8 +23,8 @@ from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha -UPLOAD_FOLDER = '/app/client/package/' -ALLOWED_EXTENSIONS = {'gz', 'bz2', 'tar', 'zip', 'tgz'} +UPLOAD_FOLDER = "/app/client/package/" +ALLOWED_EXTENSIONS = {"gz", "bz2", "tar", "zip", "tgz"} def allowed_file(filename): @@ -33,8 +33,10 @@ def allowed_file(filename): :param filename: :return: """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return ( + "." in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) def encode_auth_token(secret_key): @@ -43,16 +45,17 @@ def encode_auth_token(secret_key): """ try: payload = { - 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=90, seconds=0), - 'iat': datetime.datetime.utcnow(), - 'status': 'Success' + "exp": datetime.datetime.utcnow() + + datetime.timedelta(days=90, seconds=0), + "iat": datetime.datetime.utcnow(), + "status": "Success", } - token = jwt.encode( - payload, - secret_key, - algorithm='HS256' + token = jwt.encode(payload, secret_key, algorithm="HS256") + print( + "\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n".format( + token + ) ) - print('\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n'.format(token)) return token except Exception as e: return e @@ -64,58 +67,52 @@ def decode_auth_token(auth_token, secret): :return: string """ try: - payload = jwt.decode( - auth_token, - secret, - algorithms=['HS256'] - ) + payload = jwt.decode(auth_token, secret, algorithms=["HS256"]) return payload["status"] except jwt.ExpiredSignatureError as e: print(e) - return 'Token has expired.' + return "Token has expired." except jwt.InvalidTokenError as e: print(e) - return 'Invalid token.' + return "Invalid token." class ReducerRestService: - """ - - """ - - def __init__(self, config, control, certificate_manager): + """ """ + def __init__(self, config, control, statestore, certificate_manager): print("config object!: \n\n\n\n{}".format(config)) - if config['host']: - self.host = config['host'] + if config["host"]: + self.host = config["host"] else: self.host = None - self.name = config['name'] + self.name = config["name"] - self.port = config['port'] - self.network_id = config['name'] + '-network' + self.port = config["port"] + self.network_id = config["name"] + "-network" - if 'token' in config.keys(): + if "token" in config.keys(): self.token_auth_enabled = True else: self.token_auth_enabled = False - if 'secret_key' in config.keys(): - self.SECRET_KEY = config['secret_key'] + if "secret_key" in config.keys(): + self.SECRET_KEY = config["secret_key"] else: self.SECRET_KEY = None - if 'use_ssl' in config.keys(): - self.use_ssl = config['use_ssl'] + if "use_ssl" in config.keys(): + self.use_ssl = config["use_ssl"] self.remote_compute_package = config["remote_compute_package"] if self.remote_compute_package: - self.package = 'remote' + self.package = "remote" else: - self.package = 'local' + self.package = "local" self.control = control + self.statestore = statestore self.certificate_manager = certificate_manager self.current_compute_context = None @@ -124,9 +121,7 @@ def to_dict(self): :return: """ - data = { - 'name': self.name - } + data = {"name": self.name} return data def check_compute_package(self): @@ -151,7 +146,7 @@ def check_initial_model(self): :rtype: bool """ - if self.control.get_latest_model(): + if self.statestore.get_latest_model(): return True else: return False @@ -164,24 +159,40 @@ def check_configured_response(self): :rtype: json """ if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Controller is not configured."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Controller is not configured.", + } + ) if not self.check_compute_package(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Compute package is not configured. Please upload the compute package."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Compute package is not configured. Please upload the compute package.", + } + ) if not self.check_initial_model(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Initial model is not configured. Please upload the model."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Initial model is not configured. Please upload the model.", + } + ) if not self.control.idle(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Conroller is not in idle state, try again later. "}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Conroller is not in idle state, try again later. ", + } + ) return None def check_configured(self): @@ -191,17 +202,29 @@ def check_configured(self): :return: Rendered html template or None """ if not self.check_compute_package(): - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=False, - message='Please set the compute package') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=False, + message="Please set the compute package", + ) if self.control.state() == ReducerState.setup: - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=True, - message='Warning. Reducer is not base-configured. please do so with config file.') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=True, + message="Warning. Reducer is not base-configured. please do so with config file.", + ) if not self.check_initial_model(): - return render_template('setup_model.html', message="Please set the initial model.") + return render_template( + "setup_model.html", message="Please set the initial model." + ) return None @@ -215,31 +238,37 @@ def authorize(self, r, secret): """ try: # Get token - if 'Authorization' in r.headers: # header auth - request_token = r.headers.get('Authorization').split()[1] - elif 'token' in r.args: # args auth - request_token = str(r.args.get('token')) - elif 'fedn_token' in r.cookies: - request_token = r.cookies.get('fedn_token') + if "Authorization" in r.headers: # header auth + request_token = r.headers.get("Authorization").split()[1] + elif "token" in r.args: # args auth + request_token = str(r.args.get("token")) + elif "fedn_token" in r.cookies: + request_token = r.cookies.get("fedn_token") else: # no token provided - print('Authorization failed. No token provided.', flush=True) + print("Authorization failed. No token provided.", flush=True) abort(401) # Log token and secret print( - f'Secret: {secret}. Request token: {request_token}.', flush=True) + f"Secret: {secret}. Request token: {request_token}.", + flush=True, + ) # Authenticate status = decode_auth_token(request_token, secret) - if status == 'Success': + if status == "Success": return True else: - print('Authorization failed. Status: "{}"'.format( - status), flush=True) + print( + 'Authorization failed. Status: "{}"'.format(status), + flush=True, + ) abort(401) except Exception as e: - print('Authorization failed. Expection encountered: "{}".'.format( - e), flush=True) + print( + 'Authorization failed. Expection encountered: "{}".'.format(e), + flush=True, + ) abort(401) def run(self): @@ -249,10 +278,10 @@ def run(self): """ app = Flask(__name__) - app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER - app.config['SECRET_KEY'] = self.SECRET_KEY + app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER + app.config["SECRET_KEY"] = self.SECRET_KEY - @app.route('/') + @app.route("/") def index(): """ @@ -260,7 +289,7 @@ def index(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Render template not_configured_template = self.check_configured() @@ -268,29 +297,37 @@ def index(): template = not_configured_template else: events = self.control.get_events() - message = request.args.get('message', None) - message_type = request.args.get('message_type', None) - template = render_template('events.html', client=self.name, state=ReducerStateToString(self.control.state()), - events=events, - logs=None, refresh=True, configured=True, message=message, message_type=message_type) + message = request.args.get("message", None) + message_type = request.args.get("message_type", None) + template = render_template( + "events.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + events=events, + logs=None, + refresh=True, + configured=True, + message=message, + message_type=message_type, + ) # Set token cookie in response if needed response = make_response(template) - if 'token' in request.args: # args auth - response.set_cookie('fedn_token', str(request.args['token'])) + if "token" in request.args: # args auth + response.set_cookie("fedn_token", str(request.args["token"])) # Return response return response - @app.route('/status') + @app.route("/status") def status(): """ :return: """ - return {'state': ReducerStateToString(self.control.state())} + return {"state": ReducerStateToString(self.control.state())} - @app.route('/netgraph') + @app.route("/netgraph") def netgraph(): """ Creates nodes and edges for network graph @@ -298,16 +335,18 @@ def netgraph(): :return: nodes and edges as keys :rtype: dict """ - result = {'nodes': [], 'edges': []} - - result['nodes'].append({ - "id": "reducer", - "label": "Reducer", - "role": 'reducer', - "status": 'active', - "name": 'reducer', # TODO: get real host name - "type": 'reducer', - }) + result = {"nodes": [], "edges": []} + + result["nodes"].append( + { + "id": "reducer", + "label": "Reducer", + "role": "reducer", + "status": "active", + "name": "reducer", # TODO: get real host name + "type": "reducer", + } + ) combiner_info = combiner_status() client_info = client_status() @@ -318,49 +357,55 @@ def netgraph(): for combiner in combiner_info: print("combiner info {}".format(combiner_info), flush=True) try: - result['nodes'].append({ - "id": combiner['name'], # "n{}".format(count), - "label": "Combiner ({} clients)".format(combiner['nr_active_clients']), - "role": 'combiner', - "status": 'active', # TODO: Hard-coded, combiner_info does not contain status - "name": combiner['name'], - "type": 'combiner', - }) + result["nodes"].append( + { + "id": combiner["name"], # "n{}".format(count), + "label": "Combiner ({} clients)".format( + combiner["nr_active_clients"] + ), + "role": "combiner", + "status": "active", # TODO: Hard-coded, combiner_info does not contain status + "name": combiner["name"], + "type": "combiner", + } + ) except Exception as err: print(err) - for client in client_info['active_clients']: + for client in client_info["active_clients"]: try: - if client['status'] != 'offline': - result['nodes'].append({ - "id": str(client['_id']), - "label": "Client", - "role": client['role'], - "status": client['status'], - "name": client['name'], - "combiner": client['combiner'], - "type": 'client', - }) + if client["status"] != "offline": + result["nodes"].append( + { + "id": str(client["_id"]), + "label": "Client", + "role": client["role"], + "status": client["status"], + "name": client["name"], + "combiner": client["combiner"], + "type": "client", + } + ) except Exception as err: print(err) count = 0 - for node in result['nodes']: + for node in result["nodes"]: try: - if node['type'] == 'combiner': - result['edges'].append( + if node["type"] == "combiner": + result["edges"].append( { "id": "e{}".format(count), - "source": node['id'], - "target": 'reducer', + "source": node["id"], + "target": "reducer", } ) - elif node['type'] == 'client': - result['edges'].append( + elif node["type"] == "client": + result["edges"].append( { "id": "e{}".format(count), - "source": node['combiner'], - "target": node['id'], + "source": node["combiner"], + "target": node["id"], } ) except Exception: @@ -368,59 +413,75 @@ def netgraph(): count = count + 1 return result - @app.route('/networkgraph') + @app.route("/networkgraph") def network_graph(): - try: plot = Plot(self.control.statestore) result = netgraph() - df_nodes = pd.DataFrame(result['nodes']) - df_edges = pd.DataFrame(result['edges']) + df_nodes = pd.DataFrame(result["nodes"]) + df_edges = pd.DataFrame(result["edges"]) graph = plot.make_netgraph_plot(df_edges, df_nodes) return json.dumps(json_item(graph, "myplot")) except Exception: raise # return '' - @app.route('/events') + @app.route("/events") def events(): """ :return: """ + response = self.control.get_events() + events = [] + + result = response["result"] + + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) + json_docs = [] for doc in self.control.get_events(): json_doc = json.dumps(doc, default=json_util.default) json_docs.append(json_doc) json_docs.reverse() - return {'events': json_docs} - @app.route('/add') + return {"events": json_docs} + + @app.route("/add") def add(): - """ Add a combiner to the network. """ + """Add a combiner to the network.""" print("Adding combiner to network:", flush=True) if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry'}) - - name = request.args.get('name', None) - address = str(request.args.get('address', None)) - fqdn = str(request.args.get('fqdn', None)) - port = request.args.get('port', None) - secure_grpc = request.args.get('secure', None) - - if port is None or address is None or name is None or secure_grpc is None: + return jsonify({"status": "retry"}) + + name = request.args.get("name", None) + address = str(request.args.get("address", None)) + fqdn = str(request.args.get("fqdn", None)) + port = request.args.get("port", None) + secure_grpc = request.args.get("secure", None) + + if ( + port is None + or address is None + or name is None + or secure_grpc is None + ): return "Please specify correct parameters." # Try to retrieve combiner from db combiner = self.control.network.get_combiner(name) if not combiner: - if secure_grpc == 'True': + if secure_grpc == "True": certificate, key = self.certificate_manager.get_or_create( - address).get_keypair_raw() + address + ).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -436,23 +497,24 @@ def add(): port=port, certificate=copy.deepcopy(certificate), key=copy.deepcopy(key), - ip=request.remote_addr) + ip=request.remote_addr, + ) self.control.network.add_combiner(combiner) combiner = self.control.network.get_combiner(name) ret = { - 'status': 'added', - 'storage': self.control.statestore.get_storage_backend(), - 'statestore': self.control.statestore.get_config(), - 'certificate': combiner.get_certificate(), - 'key': combiner.get_key() + "status": "added", + "storage": self.control.statestore.get_storage_backend(), + "statestore": self.control.statestore.get_config(), + "certificate": combiner.get_certificate(), + "key": combiner.get_key(), } return jsonify(ret) - @app.route('/eula', methods=['GET', 'POST']) + @app.route("/eula", methods=["GET", "POST"]) def eula(): """ @@ -461,9 +523,9 @@ def eula(): for r in request.headers: print("header contains: {}".format(r), flush=True) - return render_template('eula.html', configured=True) + return render_template("eula.html", configured=True) - @app.route('/models', methods=['GET', 'POST']) + @app.route("/models", methods=["GET", "POST"]) def models(): """ @@ -471,13 +533,12 @@ def models(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) - if request.method == 'POST': + if request.method == "POST": # upload seed file - uploaded_seed = request.files['seed'] + uploaded_seed = request.files["seed"] if uploaded_seed: - a = BytesIO() a.seek(0, 0) uploaded_seed.seek(0) @@ -500,26 +561,34 @@ def models(): box_plot = None print(e, flush=True) - h_latest_model_id = self.control.get_latest_model() + h_latest_model_id = self.statestore.get_latest_model() model_info = self.control.get_model_info() - return render_template('models.html', box_plot=box_plot, metrics=valid_metrics, h_latest_model_id=h_latest_model_id, seed=True, - model_info=model_info, configured=True) + return render_template( + "models.html", + box_plot=box_plot, + metrics=valid_metrics, + h_latest_model_id=h_latest_model_id, + seed=True, + model_info=model_info, + configured=True, + ) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/delete_model_trail', methods=['GET', 'POST']) + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """ :return: """ - if request.method == 'POST': - + if request.method == "POST": statestore_config = self.control.statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], + statestore_config["network_id"], + ) try: self.control.drop_models() except Exception: @@ -527,28 +596,28 @@ def delete_model_trail(): # drop objects in minio self.control.delete_bucket_objects() - return redirect(url_for('models')) + return redirect(url_for("models")) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/drop_control', methods=['GET', 'POST']) + @app.route("/drop_control", methods=["GET", "POST"]) def drop_control(): """ :return: """ - if request.method == 'POST': + if request.method == "POST": self.control.statestore.drop_control() - return redirect(url_for('control')) - return redirect(url_for('control')) + return redirect(url_for("control")) + return redirect(url_for("control")) # http://localhost:8090/control?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548 - @app.route('/control', methods=['GET', 'POST']) + @app.route("/control", methods=["GET", "POST"]) def control(): - """ Main page for round control. Configure, start and stop training sessions. """ + """Main page for round control. Configure, start and stop training sessions.""" # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -559,105 +628,145 @@ def control(): if self.remote_compute_package: try: - self.current_compute_context = self.control.get_compute_package_name() + self.current_compute_context = ( + self.control.get_compute_package_name() + ) except Exception: self.current_compute_context = None else: self.current_compute_context = "None:Local" if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=state, refresh=refresh, message="Reducer is in monitoring state")) - - if request.method == 'POST': + url_for( + "index", + state=state, + refresh=refresh, + message="Reducer is in monitoring state", + ) + ) + + if request.method == "POST": # Get session configuration - round_timeout = float(request.form.get('timeout', 180)) - rounds = int(request.form.get('rounds', 1)) - delete_models = request.form.get('delete_models', True) - task = (request.form.get('task', '')) - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + round_timeout = float(request.form.get("timeout", 180)) + buffer_size = int(request.form.get("buffer_size", -1)) + rounds = int(request.form.get("rounds", 1)) + delete_models = request.form.get("delete_models", True) + task = request.form.get("task", "") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # checking if there are enough clients connected to start! clients_available = 0 for combiner in self.control.network.get_combiners(): try: combiner_state = combiner.report() - nac = combiner_state['nr_active_clients'] + nac = combiner_state["nr_active_clients"] clients_available = clients_available + int(nac) except Exception: pass if clients_available < clients_required: - return redirect(url_for('index', state=state, - message="Not enough clients available to start rounds! " - "check combiner client capacity", - message_type='warning')) + return redirect( + url_for( + "index", + state=state, + message="Not enough clients available to start rounds! " + "check combiner client capacity", + message_type="warning", + ) + ) - validate = request.form.get('validate', False) - if validate == 'False': + validate = request.form.get("validate", False) + if validate == "False": validate = False - helper_type = request.form.get('helper', 'keras') + helper_type = request.form.get("helper", "keras") # self.control.statestore.set_framework(helper_type) - latest_model_id = self.control.get_latest_model() + latest_model_id = self.statestore.get_latest_model() + + config = { + "round_timeout": round_timeout, + "buffer_size": buffer_size, + "model_id": latest_model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": task, + "validate": validate, + "helper_type": helper_type, + } + + threading.Thread( + target=self.control.session, args=(config,) + ).start() - config = {'round_timeout': round_timeout, 'model_id': latest_model_id, - 'rounds': rounds, 'delete_models_storage': delete_models, - 'clients_required': clients_required, - 'clients_requested': clients_requested, 'task': task, - 'validate': validate, 'helper_type': helper_type} - - threading.Thread(target=self.control.session, - args=(config,)).start() - - return redirect(url_for('index', state=state, refresh=refresh, message="Sent execution plan.", - message_type='SUCCESS')) + return redirect( + url_for( + "index", + state=state, + refresh=refresh, + message="Sent execution plan.", + message_type="SUCCESS", + ) + ) else: seed_model_id = None latest_model_id = None try: - seed_model_id = self.control.get_first_model()[0] - latest_model_id = self.control.get_latest_model() + seed_model_id = self.statestore.get_initial_model() + latest_model_id = self.statestore.get_latest_model() except Exception: pass - return render_template('index.html', latest_model_id=latest_model_id, - compute_package=self.current_compute_context, - seed_model_id=seed_model_id, - helper=self.control.statestore.get_helper(), validate=True, configured=True) - - @app.route('/assign') + return render_template( + "index.html", + latest_model_id=latest_model_id, + compute_package=self.current_compute_context, + seed_model_id=seed_model_id, + helper=self.control.statestore.get_helper(), + validate=True, + configured=True, + ) + + @app.route("/assign") def assign(): - """Handle client assignment requests. """ + """Handle client assignment requests.""" if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) response = self.check_configured_response() if response: return response - name = request.args.get('name', None) - combiner_preferred = request.args.get('combiner', None) + name = request.args.get("name", None) + combiner_preferred = request.args.get("combiner", None) if combiner_preferred: - combiner = self.control.network.get_combiner(combiner_preferred) + combiner = self.control.network.get_combiner( + combiner_preferred + ) else: combiner = self.control.network.find_available_combiner() if combiner is None: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Failed to assign to a combiner, try again later."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Failed to assign to a combiner, try again later.", + } + ) client = { - 'name': name, - 'combiner_preferred': combiner_preferred, - 'combiner': combiner.name, - 'ip': request.remote_addr, - 'status': 'available' + "name": name, + "combiner_preferred": combiner_preferred, + "combiner": combiner.name, + "ip": request.remote_addr, + "status": "available", } # Add client to database @@ -666,25 +775,25 @@ def assign(): # Return connection information to client if combiner.certificate: cert_b64 = base64.b64encode(combiner.certificate) - cert = str(cert_b64).split('\'')[1] + cert = str(cert_b64).split("'")[1] else: cert = None response = { - 'status': 'assigned', - 'host': combiner.address, - 'fqdn': combiner.fqdn, - 'package': self.package, - 'ip': combiner.ip, - 'port': combiner.port, - 'certificate': cert, - 'model_type': self.control.statestore.get_helper() + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": self.package, + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "model_type": self.control.statestore.get_helper(), } return jsonify(response) def combiner_status(): - """ Get current status reports from all combiners registered in the network. + """Get current status reports from all combiners registered in the network. :return: """ @@ -709,67 +818,90 @@ def client_status(): all_active_validators = [] for client in combiner_info: - active_trainers_str = client['active_trainers'] - active_validators_str = client['active_validators'] + active_trainers_str = client["active_trainers"] + active_validators_str = client["active_validators"] active_trainers_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_trainers_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_trainers_str # noqa: W605 + ).replace( + "name:", " " + ) active_validators_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_validators_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_validators_str # noqa: W605 + ).replace( + "name:", " " + ) all_active_trainers.extend( - ' '.join(active_trainers_str.split(" ")).split()) + " ".join(active_trainers_str.split(" ")).split() + ) all_active_validators.extend( - ' '.join(active_validators_str.split(" ")).split()) + " ".join(active_validators_str.split(" ")).split() + ) active_trainers_list = [ - client for client in client_info if client['name'] in all_active_trainers] + client + for client in client_info + if client["name"] in all_active_trainers + ] active_validators_list = [ - cl for cl in client_info if cl['name'] in all_active_validators] + cl + for cl in client_info + if cl["name"] in all_active_validators + ] all_clients = [cl for cl in client_info] for client in all_clients: - status = 'offline' - role = 'None' + status = "offline" + role = "None" self.control.network.update_client_data( - client, status, role) + client, status, role + ) - all_active_clients = active_validators_list + active_trainers_list + all_active_clients = ( + active_validators_list + active_trainers_list + ) for client in all_active_clients: - status = 'active' - if client in active_trainers_list and client in active_validators_list: - role = 'trainer-validator' + status = "active" + if ( + client in active_trainers_list + and client in active_validators_list + ): + role = "trainer-validator" elif client in active_trainers_list: - role = 'trainer' + role = "trainer" elif client in active_validators_list: - role = 'validator' + role = "validator" else: - role = 'unknown' + role = "unknown" self.control.network.update_client_data( - client, status, role) - - return {'active_clients': all_clients, - 'active_trainers': active_trainers_list, - 'active_validators': active_validators_list - } + client, status, role + ) + + return { + "active_clients": all_clients, + "active_trainers": active_trainers_list, + "active_validators": active_validators_list, + } except Exception: pass - return {'active_clients': [], - 'active_trainers': [], - 'active_validators': [] - } + return { + "active_clients": [], + "active_trainers": [], + "active_validators": [], + } - @app.route('/metric_type', methods=['GET', 'POST']) + @app.route("/metric_type", methods=["GET", "POST"]) def change_features(): """ :return: """ - feature = request.args['selected'] + feature = request.args["selected"] plot = Plot(self.control.statestore) graphJSON = plot.create_box_plot(feature) return graphJSON - @app.route('/dashboard') + @app.route("/dashboard") def dashboard(): """ @@ -777,7 +909,7 @@ def dashboard(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -791,16 +923,18 @@ def dashboard(): clients_plot = plot.create_client_plot() client_histogram_plot = plot.create_client_histogram_plot() - return render_template('dashboard.html', show_plot=True, - table_plot=table_plot, - timeline_plot=timeline_plot, - clients_plot=clients_plot, - client_histogram_plot=client_histogram_plot, - combiners_plot=combiners_plot, - configured=True - ) - - @app.route('/network') + return render_template( + "dashboard.html", + show_plot=True, + table_plot=table_plot, + timeline_plot=timeline_plot, + clients_plot=clients_plot, + client_histogram_plot=client_histogram_plot, + combiners_plot=combiners_plot, + configured=True, + ) + + @app.route("/network") def network(): """ @@ -808,7 +942,7 @@ def network(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -819,17 +953,19 @@ def network(): combiner_info = combiner_status() active_clients = client_status() # print(combiner_info, flush=True) - return render_template('network.html', network_plot=True, - round_time_plot=round_time_plot, - mem_cpu_plot=mem_cpu_plot, - combiner_info=combiner_info, - active_clients=active_clients['active_clients'], - active_trainers=active_clients['active_trainers'], - active_validators=active_clients['active_validators'], - configured=True - ) - - @app.route('/config/download', methods=['GET']) + return render_template( + "network.html", + network_plot=True, + round_time_plot=round_time_plot, + mem_cpu_plot=mem_cpu_plot, + combiner_info=combiner_info, + active_clients=active_clients["active_clients"], + active_trainers=active_clients["active_trainers"], + active_validators=active_clients["active_validators"], + configured=True, + ) + + @app.route("/config/download", methods=["GET"]) def config_download(): """ @@ -837,8 +973,8 @@ def config_download(): """ chk_string = "" name = self.control.get_compute_package_name() - if name is None or name == '': - chk_string = '' + if name is None or name == "": + chk_string = "" else: file_path = os.path.join(UPLOAD_FOLDER, name) print("trying to get {}".format(file_path)) @@ -846,7 +982,7 @@ def config_download(): try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' + sum = "" chk_string = "checksum: {}".format(sum) network_id = self.network_id @@ -855,20 +991,24 @@ def config_download(): ctx = """network_id: {network_id} discover_host: {discover_host} discover_port: {discover_port} -{chk_string}""".format(network_id=network_id, - discover_host=discover_host, - discover_port=discover_port, - chk_string=chk_string) +{chk_string}""".format( + network_id=network_id, + discover_host=discover_host, + discover_port=discover_port, + chk_string=chk_string, + ) obj = BytesIO() - obj.write(ctx.encode('UTF-8')) + obj.write(ctx.encode("UTF-8")) obj.seek(0) - return send_file(obj, - as_attachment=True, - download_name='client.yaml', - mimetype='application/x-yaml') - - @app.route('/context', methods=['GET', 'POST']) + return send_file( + obj, + as_attachment=True, + download_name="client.yaml", + mimetype="application/x-yaml", + ) + + @app.route("/context", methods=["GET", "POST"]) def context(): """ @@ -876,78 +1016,85 @@ def context(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # if reset is not empty then allow context re-set - reset = request.args.get('reset', None) + reset = request.args.get("reset", None) if reset: - return render_template('context.html') + return render_template("context.html") - if request.method == 'POST': + if request.method == "POST": + if "file" not in request.files: + flash("No file part") + return redirect(url_for("context")) - if 'file' not in request.files: - flash('No file part') - return redirect(url_for('context')) - - file = request.files['file'] - helper_type = request.form.get('helper', 'kerashelper') + file = request.files["file"] + helper_type = request.form.get("helper", "kerashelper") # if user does not select file, browser also # submit an empty part without filename - if file.filename == '': - flash('No selected file') - return redirect(url_for('context')) + if file.filename == "": + flash("No selected file") + return redirect(url_for("context")) if file and allowed_file(file.filename): filename = secure_filename(file.filename) file_path = os.path.join( - app.config['UPLOAD_FOLDER'], filename) + app.config["UPLOAD_FOLDER"], filename + ) file.save(file_path) - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): return "Not allowed to change context while execution is ongoing." self.control.set_compute_package(filename, file_path) self.control.statestore.set_helper(helper_type) - return redirect(url_for('control')) + return redirect(url_for("control")) - name = request.args.get('name', '') + name = request.args.get("name", "") - if name == '': + if name == "": name = self.control.get_compute_package_name() - if name is None or name == '': - return render_template('context.html') + if name is None or name == "": + return render_template("context.html") # There is a potential race condition here, if one client requests a package and at # the same time another one triggers a fetch from Minio and writes to disk. try: mutex = Lock() mutex.acquire() - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: try: data = self.control.get_compute_package(name) - file_path = os.path.join(app.config['UPLOAD_FOLDER'], name) - with open(file_path, 'wb') as fh: + file_path = os.path.join(app.config["UPLOAD_FOLDER"], name) + with open(file_path, "wb") as fh: fh.write(data) - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: raise finally: mutex.release() - return render_template('context.html') + return render_template("context.html") - @app.route('/checksum', methods=['GET', 'POST']) + @app.route("/checksum", methods=["GET", "POST"]) def checksum(): """ :return: """ # sum = '' - name = request.args.get('name', None) - if name == '' or name is None: + name = request.args.get("name", None) + if name == "" or name is None: name = self.control.get_compute_package_name() - if name is None or name == '': + if name is None or name == "": return jsonify({}) file_path = os.path.join(UPLOAD_FOLDER, name) @@ -956,13 +1103,13 @@ def checksum(): try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' + sum = "" - data = {'checksum': sum} + data = {"checksum": sum} return jsonify(data) - @app.route('/infer', methods=['POST']) + @app.route("/infer", methods=["POST"]) def infer(): """ @@ -970,7 +1117,7 @@ def infer(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Check configured not_configured = self.check_configured() @@ -980,7 +1127,9 @@ def infer(): # Check compute context if self.remote_compute_context: try: - self.current_compute_context = self.control.get_compute_package() + self.current_compute_context = ( + self.control.get_compute_package() + ) except Exception as e: print(e, flush=True) self.current_compute_context = None @@ -990,27 +1139,43 @@ def infer(): # Redirect if in monitoring state if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Reducer is in monitoring state")) + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Reducer is in monitoring state", + ) + ) # POST params - timeout = int(request.form.get('timeout', 180)) - helper_type = request.form.get('helper', 'keras') - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + timeout = int(request.form.get("timeout", 180)) + helper_type = request.form.get("helper", "keras") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # Start inference request - config = {'round_timeout': timeout, - 'model_id': self.control.get_latest_model(), - 'clients_required': clients_required, - 'clients_requested': clients_requested, - 'task': 'inference', - 'helper_type': helper_type} - threading.Thread(target=self.control.infer_instruct, - args=(config,)).start() + config = { + "round_timeout": timeout, + "model_id": self.statestore.get_latest_model(), + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": "inference", + "helper_type": helper_type, + } + threading.Thread( + target=self.control.infer_instruct, args=(config,) + ).start() # Redirect - return redirect(url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Sent execution plan (inference).", - message_type='SUCCESS')) + return redirect( + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Sent execution plan (inference).", + message_type="SUCCESS", + ) + ) if not self.host: bind = "0.0.0.0" diff --git a/fedn/fedn/network/dashboard/templates/events.html b/fedn/fedn/network/dashboard/templates/events.html index d3c34beb5..1fb5fac74 100644 --- a/fedn/fedn/network/dashboard/templates/events.html +++ b/fedn/fedn/network/dashboard/templates/events.html @@ -3,41 +3,44 @@ {% block content %} -
-
-
Events
-
-
- - - - + + + -
- -
+ }); + +
+
+ -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/fedn/fedn/network/dashboard/templates/index.html b/fedn/fedn/network/dashboard/templates/index.html index a4182130d..4ac2d182b 100644 --- a/fedn/fedn/network/dashboard/templates/index.html +++ b/fedn/fedn/network/dashboard/templates/index.html @@ -305,6 +305,16 @@
New session
+ + +
+ +
+ +
diff --git a/fedn/fedn/network/loadbalancer/__init__.py b/fedn/fedn/network/loadbalancer/__init__.py index e69de29bb..d0e44bf3c 100644 --- a/fedn/fedn/network/loadbalancer/__init__.py +++ b/fedn/fedn/network/loadbalancer/__init__.py @@ -0,0 +1 @@ +""" The loadbalancer package is responsible for loadbalancing the clients to the combiners. """ diff --git a/fedn/fedn/network/loadbalancer/firstavailable.py b/fedn/fedn/network/loadbalancer/firstavailable.py index 6ffed0806..9d44d3fbd 100644 --- a/fedn/fedn/network/loadbalancer/firstavailable.py +++ b/fedn/fedn/network/loadbalancer/firstavailable.py @@ -2,6 +2,11 @@ class LeastPacked(LoadBalancerBase): + """ Load balancer that selects the first available combiner. + + :param network: A handle to the network. + :type network: class: `fedn.network.api.network.Network` + """ def __init__(self, network): super().__init__(network) diff --git a/fedn/fedn/network/loadbalancer/leastpacked.py b/fedn/fedn/network/loadbalancer/leastpacked.py index 588e6b491..9e4aaba0d 100644 --- a/fedn/fedn/network/loadbalancer/leastpacked.py +++ b/fedn/fedn/network/loadbalancer/leastpacked.py @@ -3,6 +3,11 @@ class LeastPacked(LoadBalancerBase): + """ Load balancer that selects the combiner with the least number of attached clients. + + :param network: A handle to the network. + :type network: class: `fedn.network.api.network.Network` + """ def __init__(self, network): super().__init__(network) diff --git a/fedn/fedn/network/loadbalancer/loadbalancerbase.py b/fedn/fedn/network/loadbalancer/loadbalancerbase.py index cc39c4200..ff1edfa9b 100644 --- a/fedn/fedn/network/loadbalancer/loadbalancerbase.py +++ b/fedn/fedn/network/loadbalancer/loadbalancerbase.py @@ -2,7 +2,11 @@ class LoadBalancerBase(ABC): - """ Abstract base class for load balancers. """ + """ Abstract base class for load balancers. + + :param network: A handle to the network. + :type network: class: `fedn.network.api.network.Network` + """ def __init__(self, network): """ """ diff --git a/fedn/fedn/network/reducer.py b/fedn/fedn/network/reducer.py index 47f6aca6d..d61186864 100644 --- a/fedn/fedn/network/reducer.py +++ b/fedn/fedn/network/reducer.py @@ -23,16 +23,12 @@ class MissingReducerConfiguration(Exception): class Reducer: """ A class used to instantiate the Reducer service. - Start Reducer service. + :param statestore: The backend statestore object. + :type statestore: :class:`fedn.network.statestore.statestorebase.StateStoreBase` """ def __init__(self, statestore): - """ - Parameters - ---------- - statestore: dict - The backend statestore object. - """ + """ Constructor""" self.statestore = statestore config = self.statestore.get_reducer() @@ -52,7 +48,7 @@ def __init__(self, statestore): self.control = Control(self.statestore) self.rest = ReducerRestService( - config, self.control, self.certificate_manager) + config, self.control, self.statestore, self.certificate_manager) def run(self): """Start REST service and control loop.""" diff --git a/fedn/fedn/network/state.py b/fedn/fedn/network/state.py index ab1e33e0f..9d18bc924 100644 --- a/fedn/fedn/network/state.py +++ b/fedn/fedn/network/state.py @@ -2,6 +2,7 @@ class ReducerState(Enum): + """ Enum for representing the state of a reducer.""" setup = 1 idle = 2 instructing = 3 @@ -9,10 +10,12 @@ class ReducerState(Enum): def ReducerStateToString(state): - """ + """ Convert ReducerState to string. - :param state: - :return: + :param state: The state. + :type state: :class:`fedn.network.state.ReducerState` + :return: The state as string. + :rtype: str """ if state == ReducerState.setup: return "setup" @@ -27,10 +30,12 @@ def ReducerStateToString(state): def StringToReducerState(state): - """ + """ Convert string to ReducerState. - :param state: - :return: + :param state: The state as string. + :type state: str + :return: The state. + :rtype: :class:`fedn.network.state.ReducerState` """ if state == "setup": return ReducerState.setup diff --git a/fedn/fedn/network/statestore/mongostatestore.py b/fedn/fedn/network/statestore/mongostatestore.py index b9f9d74e1..19d514f59 100644 --- a/fedn/fedn/network/statestore/mongostatestore.py +++ b/fedn/fedn/network/statestore/mongostatestore.py @@ -2,7 +2,6 @@ from datetime import datetime import pymongo -import yaml from fedn.common.storage.db.mongo import connect_to_mongodb from fedn.network.state import ReducerStateToString, StringToReducerState @@ -11,11 +10,18 @@ class MongoStateStore(StateStoreBase): + """Statestore implementation using MongoDB. + + :param network_id: The network id. + :type network_id: str + :param config: The statestore configuration. + :type config: dict + :param defaults: The default configuration. Given by config/settings-reducer.yaml.template + :type defaults: dict """ - """ - - def __init__(self, network_id, config, defaults=None): + def __init__(self, network_id, config, model_storage_config): + """Constructor.""" self.__inited = False try: self.config = config @@ -23,19 +29,19 @@ def __init__(self, network_id, config, defaults=None): self.mdb = connect_to_mongodb(self.config, self.network_id) # FEDn network - self.network = self.mdb['network'] - self.reducer = self.network['reducer'] - self.combiners = self.network['combiners'] - self.clients = self.network['clients'] - self.storage = self.network['storage'] + self.network = self.mdb["network"] + self.reducer = self.network["reducer"] + self.combiners = self.network["combiners"] + self.clients = self.network["clients"] + self.storage = self.network["storage"] # Control - self.control = self.mdb['control'] - self.package = self.control['package'] - self.state = self.control['state'] - self.model = self.control['model'] - self.sessions = self.control['sessions'] - self.rounds = self.control['rounds'] + self.control = self.mdb["control"] + self.package = self.control["package"] + self.state = self.control["state"] + self.model = self.control["model"] + self.sessions = self.control["sessions"] + self.rounds = self.control["rounds"] # Logging self.status = self.control["status"] @@ -51,207 +57,360 @@ def __init__(self, network_id, config, defaults=None): self.clients = None raise - if defaults: - with open(defaults, 'r') as file: - try: - settings = dict(yaml.safe_load(file)) - print(settings, flush=True) - - # Control settings - if "control" in settings and settings["control"]: - control = settings['control'] - try: - self.transition(str(control['state'])) - except KeyError: - self.transition("idle") - - if "model" in control: - if not self.get_latest(): - self.set_latest(str(control['model'])) - else: - print( - "Model trail already initialized - refusing to overwrite from config. Purge model trail if you want to reseed the system.", - flush=True) - - if "context" in control: - print("Setting filepath to {}".format( - control['context']), flush=True) - # TODO Fix the ugly latering of indirection due to a bug in secure_filename returning an object with filename as attribute - # TODO fix with unboxing of value before storing and where consuming. - self.control.config.update_one({'key': 'package'}, - {'$set': {'filename': control['context']}}, True) - if "helper" in control: - # self.set_framework(control['helper']) - pass - - round_config = {'timeout': 180, 'validate': True} - try: - round_config['timeout'] = control['timeout'] - except Exception: - pass - - try: - round_config['validate'] = control['validate'] - except Exception: - pass - - # Storage settings - self.set_storage_backend(settings['storage']) - - self.__inited = True - except yaml.YAMLError as e: - print(e) + # Storage settings + self.set_storage_backend(model_storage_config) + self.__inited = True def is_inited(self): - """ Check if the statestore is intialized. + """Check if the statestore is intialized. - :return: + :return: True if initialized, else False. + :rtype: bool """ return self.__inited def get_config(self): """Retrive the statestore config. - :return: + :return: The statestore config. + :rtype: dict """ data = { - 'type': 'MongoDB', - 'mongo_config': self.config, - 'network_id': self.network_id + "type": "MongoDB", + "mongo_config": self.config, + "network_id": self.network_id, } return data def state(self): - """ + """Get the current state. - :return: + :return: The current state. + :rtype: str """ - return StringToReducerState(self.state.find_one()['current_state']) + return StringToReducerState(self.state.find_one()["current_state"]) def transition(self, state): - """ + """Transition to a new state. - :param state: + :param state: The new state. + :type state: str :return: """ - old_state = self.state.find_one({'state': 'current_state'}) + old_state = self.state.find_one({"state": "current_state"}) if old_state != state: - return self.state.update_one({'state': 'current_state'}, {'$set': {'state': ReducerStateToString(state)}}, True) + return self.state.update_one( + {"state": "current_state"}, + {"$set": {"state": ReducerStateToString(state)}}, + True, + ) else: - print("Not updating state, already in {}".format( - ReducerStateToString(state))) - - def set_latest(self, model_id): + print( + "Not updating state, already in {}".format( + ReducerStateToString(state) + ) + ) + + def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): + """Get all sessions. + + :param limit: The maximum number of sessions to return. + :type limit: int + :param skip: The number of sessions to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: Dictionary of sessions in result (array of session objects) and count. """ - :param model_id: + result = None + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = self.sessions.find().limit(limit).skip(skip).sort( + sort_key, sort_order + ) + else: + result = self.sessions.find().sort( + sort_key, sort_order + ) + + count = self.sessions.count_documents({}) + + return { + "result": result, + "count": count, + } + + def get_session(self, session_id): + """Get session with id. + + :param session_id: The session id. + :type session_id: str + :return: The session. + :rtype: ObjectID """ + return self.sessions.find_one({"session_id": session_id}) - self.model.update_one({'key': 'current_model'}, { - '$set': {'model': model_id}}, True) - self.model.update_one({'key': 'model_trail'}, {'$push': {'model': model_id, 'committed_at': str(datetime.now())}}, - True) + def set_latest_model(self, model_id, session_id=None): + """Set the latest model id. - def get_first(self): - """ Return model_id for the latest model in the model_trail """ + :param model_id: The model id. + :type model_id: str + :return: + """ + + committed_at = datetime.now() + + self.model.insert_one( + { + "key": "models", + "model": model_id, + "session_id": session_id, + "committed_at": committed_at, + } + ) + + self.model.update_one( + {"key": "current_model"}, {"$set": {"model": model_id}}, True + ) + self.model.update_one( + {"key": "model_trail"}, + { + "$push": { + "model": model_id, + "committed_at": str(committed_at), + } + }, + True, + ) + + def get_initial_model(self): + """Return model_id for the initial model in the model trail + + :return: The initial model id. None if no model is found. + :rtype: str + """ - ret = self.model.find_one({'key': 'model_trail'}, sort=[ - ("committed_at", pymongo.ASCENDING)]) - if ret is None: + result = self.model.find_one( + {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)] + ) + if result is None: return None try: - model_id = ret['model'] - if model_id == '' or model_id == ' ': # ugly check for empty string + model_id = result["model"] + if model_id == "" or model_id == " ": return None - return model_id + return model_id[0] except (KeyError, IndexError): return None - def get_latest(self): - """ Return model_id for the latest model in the model_trail """ - ret = self.model.find_one({'key': 'current_model'}) - if ret is None: + def get_latest_model(self): + """Return model_id for the latest model in the model_trail + + :return: The latest model id. None if no model is found. + :rtype: str + """ + result = self.model.find_one({"key": "current_model"}) + if result is None: return None try: - model_id = ret['model'] - if model_id == '' or model_id == ' ': # ugly check for empty string + model_id = result["model"] + if model_id == "" or model_id == " ": return None return model_id except (KeyError, IndexError): return None def get_latest_round(self): - """ Get the id of the most recent round. """ + """Get the id of the most recent round. + + :return: The id of the most recent round. + :rtype: ObjectId + """ return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)]) def get_round(self, id): - """ Get round with id 'id'. """ + """Get round with id. + + :param id: id of round to get + :type id: int + :return: round with id, reducer and combiners + :rtype: ObjectId + """ + + return self.rounds.find_one({"round_id": str(id)}) + + def get_rounds(self): + """Get all rounds. + + :return: All rounds. + :rtype: ObjectId + """ - return self.rounds.find_one({'round_id': str(id)}) + return self.rounds.find() + + def get_validations(self, **kwargs): + """Get validations from the database. + + :param kwargs: query to filter validations + :type kwargs: dict + :return: validations matching query + :rtype: ObjectId + """ + + result = self.control.validations.find(kwargs) + return result def set_compute_package(self, filename): - """ Set the active compute package. + """Set the active compute package in statestore. - :param filename: + :param filename: The filename of the compute package. + :type filename: str + :return: True if successful. + :rtype: bool """ self.control.package.update_one( - {'key': 'active'}, {'$set': {'filename': filename}}, True) - self.control.package.update_one({'key': 'package_trail'}, - {'$push': {'filename': filename, 'committed_at': str(datetime.now())}}, True) + {"key": "active"}, + { + "$set": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) + self.control.package.update_one( + {"key": "package_trail"}, + { + "$push": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) + return True def get_compute_package(self): - """ Get the active compute package. + """Get the active compute package. - :return: + :return: The active compute package. + :rtype: ObjectID """ - ret = self.control.package.find({'key': 'active'}) + ret = self.control.package.find({"key": "active"}) try: retcheck = ret[0] - if retcheck is None or retcheck == '' or retcheck == ' ': # ugly check for empty string + if ( + retcheck is None or retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None def set_helper(self, helper): - """ + """Set the active helper package in statestore. - :param helper: + :param helper: The name of the helper package. See helper.py for available helpers. + :type helper: str + :return: """ - self.control.package.update_one({'key': 'active'}, - {'$set': {'helper': helper}}, True) + self.control.package.update_one( + {"key": "active"}, {"$set": {"helper": helper}}, True + ) def get_helper(self): - """ + """Get the active helper package. - :return: + :return: The active helper set for the package. + :rtype: str """ - ret = self.control.package.find_one({'key': 'active'}) + ret = self.control.package.find_one({"key": "active"}) # if local compute package used, then 'package' is None # if not ret: # get framework from round_config instead # ret = self.control.config.find_one({'key': 'round_config'}) try: - retcheck = ret['helper'] - if retcheck == '' or retcheck == ' ': # ugly check for empty string + retcheck = ret["helper"] + if ( + retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None - def get_model_info(self): + def list_models( + self, + session_id=None, + limit=None, + skip=None, + sort_key="committed_at", + sort_order=pymongo.DESCENDING, + ): + """List all models in the statestore. + + :param session_id: The session id. + :type session_id: str + :param limit: The maximum number of models to return. + :type limit: int + :param skip: The number of models to skip. + :type skip: int + :return: List of models. + :rtype: list """ + result = None - :return: + find_option = ( + {"key": "models"} + if session_id is None + else {"key": "models", "session_id": session_id} + ) + + projection = {"_id": False, "key": False} + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = ( + self.model.find(find_option, projection) + .limit(limit) + .skip(skip) + .sort(sort_key, sort_order) + ) + + else: + result = self.model.find(find_option, projection).sort( + sort_key, sort_order + ) + + count = self.model.count_documents(find_option) + + return { + "result": result, + "count": count, + } + + def get_model_trail(self): + """Get the model trail. + + :return: dictionary of model_id: committed_at + :rtype: dict """ - ret = self.model.find_one({'key': 'model_trail'}) + result = self.model.find_one({"key": "model_trail"}) try: - if ret: - committed_at = ret['committed_at'] - model = ret['model'] + if result is not None: + committed_at = result["committed_at"] + model = result["model"] model_dictionary = dict(zip(model, committed_at)) return model_dictionary else: @@ -259,100 +418,204 @@ def get_model_info(self): except (KeyError, IndexError): return None - def get_events(self): - """ + def get_events(self, **kwargs): + """Get events from the database. - :return: + :param kwargs: query to filter events + :type kwargs: dict + :return: events matching query + :rtype: ObjectId """ - ret = self.control.status.find({}) - return ret + # check if kwargs is empty + + result = None + count = None + projection = {"_id": False} + + if not kwargs: + result = self.control.status.find({}, projection).sort( + "timestamp", pymongo.DESCENDING + ) + count = self.control.status.count_documents({}) + else: + limit = kwargs.pop("limit", None) + skip = kwargs.pop("skip", None) + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = ( + self.control.status.find(kwargs, projection) + .sort("timestamp", pymongo.DESCENDING) + .limit(limit) + .skip(skip) + ) + else: + result = self.control.status.find(kwargs, projection).sort( + "timestamp", pymongo.DESCENDING + ) + + count = self.control.status.count_documents(kwargs) + + return { + "result": result, + "count": count, + } def get_storage_backend(self): - """ """ + """Get the storage backend. + + :return: The storage backend. + :rtype: ObjectID + """ try: ret = self.storage.find( - {'status': 'enabled'}, projection={'_id': False}) + {"status": "enabled"}, projection={"_id": False} + ) return ret[0] except (KeyError, IndexError): return None def set_storage_backend(self, config): - """ """ + """Set the storage backend. + + :param config: The storage backend configuration. + :type config: dict + :return: + """ config = copy.deepcopy(config) - config['updated_at'] = str(datetime.now()) - config['status'] = 'enabled' + config["updated_at"] = str(datetime.now()) + config["status"] = "enabled" self.storage.update_one( - {'storage_type': config['storage_type']}, {'$set': config}, True) + {"storage_type": config["storage_type"]}, {"$set": config}, True + ) def set_reducer(self, reducer_data): - """ """ - reducer_data['updated_at'] = str(datetime.now()) - self.reducer.update_one({'name': reducer_data['name']}, { - '$set': reducer_data}, True) + """Set the reducer in the statestore. + + :param reducer_data: dictionary of reducer config. + :type reducer_data: dict + :return: + """ + reducer_data["updated_at"] = str(datetime.now()) + self.reducer.update_one( + {"name": reducer_data["name"]}, {"$set": reducer_data}, True + ) def get_reducer(self): - """ """ + """Get reducer.config. + + return: reducer config. + rtype: ObjectId + """ try: ret = self.reducer.find_one() return ret except Exception: return None - def list_combiners(self): - """ """ - try: - ret = self.combiners.find() - return list(ret) - except Exception: - return None - def get_combiner(self, name): - """ """ + """Get combiner by name. + + :param name: name of combiner to get. + :type name: str + :return: The combiner. + :rtype: ObjectId + """ try: - ret = self.combiners.find_one({'name': name}) + ret = self.combiners.find_one({"name": name}) return ret except Exception: return None - def get_combiners(self): - """ """ + def get_combiners(self, limit=None, skip=None, sort_key="updated_at", sort_order=pymongo.DESCENDING, projection={}): + """Get all combiners. + + :param limit: The maximum number of combiners to return. + :type limit: int + :param skip: The number of combiners to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :param projection: The projection. + :type projection: dict + :return: Dictionary of combiners in result and count. + :rtype: dict + """ + + result = None + count = None + try: - ret = self.combiners.find() - return list(ret) + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.combiners.find({}, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.combiners.find({}, projection).sort(sort_key, sort_order) + + count = self.combiners.count_documents({}) + except Exception: return None + return { + "result": result, + "count": count, + } + def set_combiner(self, combiner_data): - """ - Set or update combiner record. - combiner_data: dictionary, output of combiner.to_dict()) + """Set combiner in statestore. + + :param combiner_data: dictionary of combiner config + :type combiner_data: dict + :return: """ - combiner_data['updated_at'] = str(datetime.now()) - self.combiners.update_one({'name': combiner_data['name']}, { - '$set': combiner_data}, True) + combiner_data["updated_at"] = str(datetime.now()) + self.combiners.update_one( + {"name": combiner_data["name"]}, {"$set": combiner_data}, True + ) def delete_combiner(self, combiner): - """ Delete a combiner entry. """ + """Delete a combiner from statestore. + + :param combiner: name of combiner to delete. + :type combiner: str + :return: + """ try: - self.combiners.delete_one({'name': combiner}) + self.combiners.delete_one({"name": combiner}) except Exception: - print("WARNING, failed to delete combiner: {}".format( - combiner), flush=True) + print( + "WARNING, failed to delete combiner: {}".format(combiner), + flush=True, + ) def set_client(self, client_data): + """Set client in statestore. + + :param client_data: dictionary of client config. + :type client_data: dict + :return: """ - Set or update client record. - client_data: dictionarys - """ - client_data['updated_at'] = str(datetime.now()) - self.clients.update_one({'name': client_data['name']}, { - '$set': client_data}, True) + client_data["updated_at"] = str(datetime.now()) + self.clients.update_one( + {"name": client_data["name"]}, {"$set": client_data}, True + ) def get_client(self, name): - """ Retrive a client record by name. """ + """Get client by name. + + :param name: name of client to get. + :type name: str + :return: The client. None if not found. + :rtype: ObjectId + """ try: - ret = self.clients.find({'key': name}) + ret = self.clients.find({"key": name}) if list(ret) == []: return None else: @@ -360,23 +623,87 @@ def get_client(self, name): except Exception: return None - def list_clients(self): - """List all clients registered on the network. """ + def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", sort_order=pymongo.DESCENDING): + """List all clients registered on the network. + + :param limit: The maximum number of clients to return. + :type limit: int + :param skip: The number of clients to skip. + :type skip: int + :param status: online | offline + :type status: str + :param sort_key: The key to sort by. + """ + + result = None + count = None + try: - ret = self.clients.find() - return list(ret) - except Exception: - return None + find = {} if status is None else {"status": status} + projection = {"_id": False, "updated_at": False} - def update_client_status(self, client_data, status, role): + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.clients.find(find, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.clients.find(find, projection).sort(sort_key, sort_order) + + count = self.clients.count_documents(find) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return { + "result": result, + "count": count, + } + + def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DESCENDING): + """List all combiner data. + + :param combiners: list of combiners to get data for. + :type combiners: list + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: list of combiner data. + :rtype: list(ObjectId) """ - Set or update client status. - assign roles to the active clients (trainer, validator, trainer-validator) + + result = None + + try: + + pipeline = [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] if combiners is not None else [ + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] + + result = self.clients.aggregate(pipeline) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return result + + def update_client_status(self, client_data, status, role): + """Set or update client status. + + :param client_data: dictionary of client config. + :type client_data: dict + :param status: status of client. + :type status: str + :param role: role of client. + :type role: str + :return: """ - self.clients.update_one({"name": client_data['name']}, - {"$set": - { - "status": status, - "role": role - } - }) + self.clients.update_one( + {"name": client_data["name"]}, + {"$set": {"status": status, "role": role}}, + ) diff --git a/fedn/fedn/network/statestore/statestorebase.py b/fedn/fedn/network/statestore/statestorebase.py index 75e117731..f41e3c025 100644 --- a/fedn/fedn/network/statestore/statestorebase.py +++ b/fedn/fedn/network/statestore/statestorebase.py @@ -11,37 +11,42 @@ def __init__(self): @abstractmethod def state(self): - """ - + """ Return the current state of the statestore. """ pass @abstractmethod def transition(self, state): - """ + """ Transition the statestore to a new state. - :param state: + :param state: The new state. + :type state: str """ pass @abstractmethod - def set_latest(self, model_id): - """ + def set_latest_model(self, model_id): + """ Set the latest model id in the statestore. - :param model_id: + :param model_id: The model id. + :type model_id: str """ pass @abstractmethod - def get_latest(self): - """ + def get_latest_model(self): + """ Get the latest model id from the statestore. + :return: The model object. + :rtype: ObjectId """ pass @abstractmethod def is_inited(self): - """ + """ Check if the statestore is initialized. + :return: True if initialized, else False. + :rtype: bool """ pass diff --git a/fedn/fedn/tests/test_reducer_service.py b/fedn/fedn/tests/test_reducer_service.py index fc5ca8d9b..22e9d54af 100644 --- a/fedn/fedn/tests/test_reducer_service.py +++ b/fedn/fedn/tests/test_reducer_service.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from fedn.clients.reducer.restservice import ReducerRestService +from fedn.network.dashboard.restservice import ReducerRestService class TestInit(unittest.TestCase): @@ -78,15 +78,15 @@ def test_check_compute_package(self): def test_check_initial_model(self): - self.restservice.control.get_latest_model.return_value = 'model-uid' + self.restservice.statestore.get_latest_model.return_value = 'model-uid' retval = self.restservice.check_initial_model() self.assertTrue(retval) - self.restservice.control.get_latest_model.return_value = None + self.restservice.statestore.get_latest_model.return_value = None retval = self.restservice.check_initial_model() self.assertFalse(retval) - self.restservice.control.get_latest_model.return_value = '' + self.restservice.statestore.get_latest_model.return_value = '' retval = self.restservice.check_initial_model() self.assertFalse(retval) diff --git a/fedn/fedn/utils/__init__.py b/fedn/fedn/utils/__init__.py index e69de29bb..dc63e07e9 100644 --- a/fedn/fedn/utils/__init__.py +++ b/fedn/fedn/utils/__init__.py @@ -0,0 +1,3 @@ +""" The utils package is responsible for providing utility functions for the FEDn framework. Such as logging, checksums and other model helper functions to aggregate models. +THe helper functions is there to support aggregating various models from different ML frameworks, such as Tensorflow, PyTorch and Keras.""" +# flake8: noqa diff --git a/fedn/fedn/utils/checksum.py b/fedn/fedn/utils/checksum.py index 99fb7f840..3c7bbd3ec 100644 --- a/fedn/fedn/utils/checksum.py +++ b/fedn/fedn/utils/checksum.py @@ -2,10 +2,12 @@ def sha(fname): - """ + """ Calculate the sha256 checksum of a file. Used for computing checksums of compute packages. - :param fname: - :return: + :param fname: The file path. + :type fname: str + :return: The sha256 checksum. + :rtype: :py:class:`hashlib.sha256` """ hash = hashlib.sha256() with open(fname, "rb") as f: diff --git a/fedn/fedn/utils/dispatcher.py b/fedn/fedn/utils/dispatcher.py index 63743e6a6..3fe0a3fc1 100644 --- a/fedn/fedn/utils/dispatcher.py +++ b/fedn/fedn/utils/dispatcher.py @@ -6,18 +6,25 @@ class Dispatcher: - """ + """ Dispatcher class for compute packages. + :param config: The configuration. + :type config: dict + :param dir: The directory to dispatch to. + :type dir: str """ def __init__(self, config, dir): + """ Initialize the dispatcher.""" self.config = config self.project_dir = dir def run_cmd(self, cmd_type): - """ + """ Run a command. - :param cmd_type: + :param cmd_type: The command type. + :type cmd_type: str + :return: """ try: cmdsandargs = cmd_type.split(' ') diff --git a/fedn/fedn/utils/logger.py b/fedn/fedn/utils/logger.py index 08c2c94e1..563012996 100644 --- a/fedn/fedn/utils/logger.py +++ b/fedn/fedn/utils/logger.py @@ -3,11 +3,18 @@ class Logger: - """ + """ Logger class for Fedn. + :param log_level: The log level. + :type log_level: int + :param to_file: The name of the file to log to. + :type to_file: str + :param file_path: The path to the log file. + :type file_path: str """ def __init__(self, log_level=logging.DEBUG, to_file='', file_path=os.getcwd()): + """ Initialize the logger.""" root = logging.getLogger() root.setLevel(log_level) diff --git a/fedn/fedn/utils/plugins/__init__.py b/fedn/fedn/utils/plugins/__init__.py index e69de29bb..162a2d351 100644 --- a/fedn/fedn/utils/plugins/__init__.py +++ b/fedn/fedn/utils/plugins/__init__.py @@ -0,0 +1,3 @@ +""" The plugins package is responsible for loading model helper functions supporting different ML frameworks. The :class:`fedn.utils.plugins.helperbase.HelperBase` is +an abstract class which user can implement their own helper functions to support different ML frameworks. """ +# flake8: noqa diff --git a/fedn/fedn/utils/plugins/numpyarrayhelper.py b/fedn/fedn/utils/plugins/numpyarrayhelper.py index 9789bf541..21bf979b8 100644 --- a/fedn/fedn/utils/plugins/numpyarrayhelper.py +++ b/fedn/fedn/utils/plugins/numpyarrayhelper.py @@ -9,15 +9,28 @@ class Helper(HelperBase): """ FEDn helper class for numpy arrays. """ def increment_average(self, model, model_next, n): - """ Update an incremental average. """ + """ Update an incremental average. + + :param model: Current model weights. + :type model: numpy array. + :param model_next: New model weights. + :type model_next: numpy array. + :param n: Number of examples in new model. + :type n: int + :return: Incremental weighted average of model weights. + :rtype: :class:`numpy.array` + """ return np.add(model, (model_next - model) / n) def save(self, model, path=None): """Serialize weights/parameters to file. - :param model: - :param path: - :return: + :param model: Weights/parameters in numpy array format. + :type model: numpy array. + :param path: Path to file. + :type path: str + :return: Path to file. + :rtype: str """ if not path: _, path = tempfile.mkstemp() @@ -27,8 +40,10 @@ def save(self, model, path=None): def load(self, path): """Load weights/parameters from file or filelike. - :param path: - :return: + :param path: Path to file. + :type path: str + :return: Weights/parameters in numpy array format. + :rtype: :class:`numpy.array` """ model = np.loadtxt(path) return model diff --git a/fedn/fedn/utils/plugins/pytorchhelper.py b/fedn/fedn/utils/plugins/pytorchhelper.py index d1ce79717..17b01200c 100644 --- a/fedn/fedn/utils/plugins/pytorchhelper.py +++ b/fedn/fedn/utils/plugins/pytorchhelper.py @@ -54,6 +54,7 @@ def load(self, path): :param path: file path, filehandle, filelike. :type path: str :return: Weights of model with keys from torch state_dict. + :rtype: OrderedDict """ a = np.load(path) weights_np = OrderedDict() diff --git a/fedn/fedn/utils/process.py b/fedn/fedn/utils/process.py index 5a005e0cd..bd31f9441 100644 --- a/fedn/fedn/utils/process.py +++ b/fedn/fedn/utils/process.py @@ -5,18 +5,22 @@ def run_process(args, cwd): - """ + """ Run a process and log the output. - :param args: - :param cwd: + :param args: The arguments to the process. + :type args: list + :param cwd: The current working directory. + :type cwd: str + :return: """ status = subprocess.Popen( args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) # print(status) def check_io(): - """ + """ Check stdout/stderr of the child process. + :return: """ while True: output = status.stdout.readline().decode() diff --git a/fedn/setup.py b/fedn/setup.py index b6346388f..1dbdb951f 100644 --- a/fedn/setup.py +++ b/fedn/setup.py @@ -2,10 +2,8 @@ setup( name='fedn', - version='0.5.0-dev', + version='0.5.0', description="""Scaleout Federated Learning""", - long_description=open('README.md').read(), - long_description_content_type="text/markdown", author='Scaleout Systems AB', author_email='contact@scaleoutsystems.com', url='https://www.scaleoutsystems.com',