From e4731f72e74b6e7bd4e82d56915fedb4b379b721 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Tue, 20 Jun 2023 12:10:10 +0200 Subject: [PATCH 1/8] Hotfix/Issue#471 | Update README to use master tag for client (#472) --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 42503985a..2afb60ebc 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ To connect a client that uses the data partition 'data/clients/1/mnist.pt': -v $PWD/data/clients/1:/var/data \ -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ --network=fedn_default \ - ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch run client -in client.yaml --name client1 + ghcr.io/scaleoutsystems/fedn/fedn:master-mnist-pytorch run client -in client.yaml --name client1 You are now ready to start training the model at http://localhost:8090/control. From 78e534622c49b683190bf93aebc452337f530f75 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Thu, 29 Jun 2023 15:30:58 +0200 Subject: [PATCH 2/8] add init in controller --- fedn/fedn/network/controller/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 fedn/fedn/network/controller/__init__.py diff --git a/fedn/fedn/network/controller/__init__.py b/fedn/fedn/network/controller/__init__.py new file mode 100644 index 000000000..e69de29bb From 0209851e85373ccc05a447cd29edf024b0b22419 Mon Sep 17 00:00:00 2001 From: Andreas Hellander Date: Wed, 2 Aug 2023 15:48:54 +0200 Subject: [PATCH 3/8] Feature/SK-505 | Flush model update queues at new session + Buffer size config (#476) * model update queues gets flushed in the beginning of a new session * Changed confusing log message * buffer_size now configurable, solves sk-520 * exclude isort for protobuf files * Deleted commented code * Added response status message --------- Co-authored-by: Andreas Hellander Co-authored-by: Fredrik Wrede --- .github/workflows/code-checks.yaml | 6 +- fedn/fedn/common/net/grpc/fedn.proto | 6 +- 
fedn/fedn/common/net/grpc/fedn_pb2.py | 337 +++--- fedn/fedn/common/net/grpc/fedn_pb2_grpc.py | 1029 +++++++++-------- fedn/fedn/network/combiner/interfaces.py | 19 +- fedn/fedn/network/combiner/round.py | 10 +- fedn/fedn/network/combiner/server.py | 58 +- fedn/fedn/network/controller/control.py | 13 +- fedn/fedn/network/dashboard/restservice.py | 5 +- .../network/dashboard/templates/index.html | 10 + 10 files changed, 791 insertions(+), 702 deletions(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index c1ec38548..c76c418f9 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -18,6 +18,8 @@ jobs: --skip .venv --skip .mnist-keras --skip .mnist-pytorch + --skip fedn_pb2.py + --skip fedn_pb2_grpc.py - name: check Python formatting run: > @@ -25,12 +27,14 @@ jobs: --exclude .venv --exclude .mnist-keras --exclude .mnist-pytorch + --exclude fedn_pb2.py + --exclude fedn_pb2_grpc.py . - name: run Python linter run: > .venv/bin/flake8 . 
- --exclude ".venv,.mnist-keras,.mnist-pytorch,fedn_pb2.py" + --exclude ".venv,.mnist-keras,.mnist-pytorch,fedn_pb2.py,fedn_pb2_grpc.py" - name: check for floating imports run: > diff --git a/fedn/fedn/common/net/grpc/fedn.proto b/fedn/fedn/common/net/grpc/fedn.proto index dca66fe20..ff0ee293c 100644 --- a/fedn/fedn/common/net/grpc/fedn.proto +++ b/fedn/fedn/common/net/grpc/fedn.proto @@ -4,7 +4,6 @@ package grpc; message Response { Client sender = 1; - //string client = 1; string response = 2; } @@ -19,7 +18,6 @@ enum StatusType { message Status { Client sender = 1; - //string client = 1; string status = 2; enum LogLevel { @@ -95,6 +93,7 @@ enum ModelStatus { IN_PROGRESS_OK = 2; FAILED = 3; } + message ModelRequest { Client sender = 1; Client receiver = 2; @@ -204,7 +203,8 @@ message ReportResponse { service Control { rpc Start(ControlRequest) returns (ControlResponse); rpc Stop(ControlRequest) returns (ControlResponse); - rpc Configure(ControlRequest) returns (ReportResponse); + rpc Configure(ControlRequest) returns (ReportResponse); + rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse); rpc Report(ControlRequest) returns (ReportResponse); } diff --git a/fedn/fedn/common/net/grpc/fedn_pb2.py b/fedn/fedn/common/net/grpc/fedn_pb2.py index 19cd119c6..fa4fbb16d 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2.py @@ -2,19 +2,20 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: fedn/common/net/grpc/fedn.proto """Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import enum_type_wrapper - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x66\x65\x64n/common/net/grpc/fedn.proto\x12\x04grpc\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\x8c\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.grpc.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.grpc.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xab\x01\n\x12ModelUpdateRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xaf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 
\x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xc5\x01\n\x16ModelValidationRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x14\n\x0cis_inference\x18\x08 \x01(\x08\"\xa8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.grpc.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.grpc.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"R\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x07\x63hannel\x18\x02 
\x01(\x0e\x32\r.grpc.Channel\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.grpc.Client\"0\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.grpc.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.grpc.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"R\n\x0eReportResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 
\x01(\x0e\x32\x16.grpc.ConnectionStatus*\x84\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\r\n\tINFERENCE\x10\x05*\x86\x01\n\x07\x43hannel\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x19\n\x15MODEL_UPDATE_REQUESTS\x10\x01\x12\x11\n\rMODEL_UPDATES\x10\x02\x12\x1d\n\x19MODEL_VALIDATION_REQUESTS\x10\x03\x12\x15\n\x11MODEL_VALIDATIONS\x10\x04\x12\n\n\x06STATUS\x10\x05*F\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse0\x01\x32\xe3\x01\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x33\n\x04Stop\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x37\n\tConfigure\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse\x12\x34\n\x06Report\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.grpc.GetGlobalModelRequest\x1a\x1c.grpc.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x0c.grpc.Status0\x01\x12*\n\nSendStatus\x12\x0c.grpc.Status\x1a\x0e.grpc.Response\x12?\n\x11ListActiveClients\x12\x18.grpc.ListClientsRequest\x1a\x10.grpc.ClientList\x12\x45\n\x10\x
41\x63\x63\x65ptingClients\x12\x17.grpc.ConnectionRequest\x1a\x18.grpc.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.grpc.Heartbeat\x1a\x0e.grpc.Response\x12\x37\n\x0eReassignClient\x12\x15.grpc.ReassignRequest\x1a\x0e.grpc.Response\x12\x39\n\x0fReconnectClient\x12\x16.grpc.ReconnectRequest\x1a\x0e.grpc.Response2\xda\x04\n\x08\x43ombiner\x12T\n\x18ModelUpdateRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x18.grpc.ModelUpdateRequest0\x01\x12\x46\n\x11ModelUpdateStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x11.grpc.ModelUpdate0\x01\x12\\\n\x1cModelValidationRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x1c.grpc.ModelValidationRequest0\x01\x12N\n\x15ModelValidationStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x15.grpc.ModelValidation0\x01\x12\x42\n\x16SendModelUpdateRequest\x12\x18.grpc.ModelUpdateRequest\x1a\x0e.grpc.Response\x12\x34\n\x0fSendModelUpdate\x12\x11.grpc.ModelUpdate\x1a\x0e.grpc.Response\x12J\n\x1aSendModelValidationRequest\x12\x1c.grpc.ModelValidationRequest\x1a\x0e.grpc.Response\x12<\n\x13SendModelValidation\x12\x15.grpc.ModelValidation\x1a\x0e.grpc.Responseb\x06proto3') # noqa: E501 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1f\x66\x65\x64n/common/net/grpc/fedn.proto\x12\x04grpc\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\x8c\x02\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12(\n\tlog_level\x18\x03 \x01(\x0e\x32\x15.grpc.Status.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.grpc.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\"B\n\x08LogLevel\x12\x08\n\x04INFO\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x0b\n\x07WARNING\x10\x02\x12\t\n\x05\x45RROR\x10\x03\x12\t\n\x05\x41UDIT\x10\x04\"\xab\x01\n\x12ModelUpdateRequest\x12\x1c\n\x06sender\x18\x01 
\x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xaf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xc5\x01\n\x16ModelValidationRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x14\n\x0cis_inference\x18\x08 \x01(\x08\"\xa8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.grpc.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.grpc.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 
\x01(\x0b\x32\x0c.grpc.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\")\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"R\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\r.grpc.Channel\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.grpc.Client\"0\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.grpc.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.grpc.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.grpc.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"R\n\x0eReportResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.grpc.Client\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.grpc.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 
\x01(\x0e\x32\x16.grpc.ConnectionStatus*\x84\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\r\n\tINFERENCE\x10\x05*\x86\x01\n\x07\x43hannel\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x19\n\x15MODEL_UPDATE_REQUESTS\x10\x01\x12\x11\n\rMODEL_UPDATES\x10\x02\x12\x1d\n\x19MODEL_VALIDATION_REQUESTS\x10\x03\x12\x15\n\x11MODEL_VALIDATIONS\x10\x04\x12\n\n\x06STATUS\x10\x05*F\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03*8\n\x04Role\x12\n\n\x06WORKER\x10\x00\x12\x0c\n\x08\x43OMBINER\x10\x01\x12\x0b\n\x07REDUCER\x10\x02\x12\t\n\x05OTHER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.grpc.ModelRequest\x1a\x13.grpc.ModelResponse0\x01\x32\xa9\x02\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x33\n\x04Stop\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x37\n\tConfigure\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.grpc.ControlRequest\x1a\x15.grpc.ControlResponse\x12\x34\n\x06Report\x12\x14.grpc.ControlRequest\x1a\x14.grpc.ReportResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.grpc.GetGlobalModelRequest\x1a\x1c.grpc.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x0c.grpc.Status0\x01\x12*\n\nSendStatus\x12\x0c.grpc.Status\x1a\x0e.grpc.Response\x12?
\n\x11ListActiveClients\x12\x18.grpc.ListClientsRequest\x1a\x10.grpc.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.grpc.ConnectionRequest\x1a\x18.grpc.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.grpc.Heartbeat\x1a\x0e.grpc.Response\x12\x37\n\x0eReassignClient\x12\x15.grpc.ReassignRequest\x1a\x0e.grpc.Response\x12\x39\n\x0fReconnectClient\x12\x16.grpc.ReconnectRequest\x1a\x0e.grpc.Response2\xda\x04\n\x08\x43ombiner\x12T\n\x18ModelUpdateRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x18.grpc.ModelUpdateRequest0\x01\x12\x46\n\x11ModelUpdateStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x11.grpc.ModelUpdate0\x01\x12\\\n\x1cModelValidationRequestStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x1c.grpc.ModelValidationRequest0\x01\x12N\n\x15ModelValidationStream\x12\x1c.grpc.ClientAvailableMessage\x1a\x15.grpc.ModelValidation0\x01\x12\x42\n\x16SendModelUpdateRequest\x12\x18.grpc.ModelUpdateRequest\x1a\x0e.grpc.Response\x12\x34\n\x0fSendModelUpdate\x12\x11.grpc.ModelUpdate\x1a\x0e.grpc.Response\x12J\n\x1aSendModelValidationRequest\x12\x1c.grpc.ModelValidationRequest\x1a\x0e.grpc.Response\x12<\n\x13SendModelValidation\x12\x15.grpc.ModelValidation\x1a\x0e.grpc.Responseb\x06proto3') _STATUSTYPE = DESCRIPTOR.enum_types_by_name['StatusType'] StatusType = enum_type_wrapper.EnumTypeWrapper(_STATUSTYPE) @@ -84,164 +85,164 @@ _CONNECTIONRESPONSE = DESCRIPTOR.message_types_by_name['ConnectionResponse'] _STATUS_LOGLEVEL = _STATUS.enum_types_by_name['LogLevel'] Response = _reflection.GeneratedProtocolMessageType('Response', (_message.Message,), { - 'DESCRIPTOR': _RESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Response) -}) + 'DESCRIPTOR' : _RESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Response) + }) _sym_db.RegisterMessage(Response) Status = _reflection.GeneratedProtocolMessageType('Status', (_message.Message,), { - 'DESCRIPTOR': _STATUS, - 
'__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Status) -}) + 'DESCRIPTOR' : _STATUS, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Status) + }) _sym_db.RegisterMessage(Status) ModelUpdateRequest = _reflection.GeneratedProtocolMessageType('ModelUpdateRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELUPDATEREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelUpdateRequest) -}) + 'DESCRIPTOR' : _MODELUPDATEREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelUpdateRequest) + }) _sym_db.RegisterMessage(ModelUpdateRequest) ModelUpdate = _reflection.GeneratedProtocolMessageType('ModelUpdate', (_message.Message,), { - 'DESCRIPTOR': _MODELUPDATE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelUpdate) -}) + 'DESCRIPTOR' : _MODELUPDATE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelUpdate) + }) _sym_db.RegisterMessage(ModelUpdate) ModelValidationRequest = _reflection.GeneratedProtocolMessageType('ModelValidationRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELVALIDATIONREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelValidationRequest) -}) + 'DESCRIPTOR' : _MODELVALIDATIONREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelValidationRequest) + }) _sym_db.RegisterMessage(ModelValidationRequest) ModelValidation = _reflection.GeneratedProtocolMessageType('ModelValidation', (_message.Message,), { - 'DESCRIPTOR': _MODELVALIDATION, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelValidation) -}) + 'DESCRIPTOR' : _MODELVALIDATION, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # 
@@protoc_insertion_point(class_scope:grpc.ModelValidation) + }) _sym_db.RegisterMessage(ModelValidation) ModelRequest = _reflection.GeneratedProtocolMessageType('ModelRequest', (_message.Message,), { - 'DESCRIPTOR': _MODELREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelRequest) -}) + 'DESCRIPTOR' : _MODELREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelRequest) + }) _sym_db.RegisterMessage(ModelRequest) ModelResponse = _reflection.GeneratedProtocolMessageType('ModelResponse', (_message.Message,), { - 'DESCRIPTOR': _MODELRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ModelResponse) -}) + 'DESCRIPTOR' : _MODELRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ModelResponse) + }) _sym_db.RegisterMessage(ModelResponse) GetGlobalModelRequest = _reflection.GeneratedProtocolMessageType('GetGlobalModelRequest', (_message.Message,), { - 'DESCRIPTOR': _GETGLOBALMODELREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelRequest) -}) + 'DESCRIPTOR' : _GETGLOBALMODELREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelRequest) + }) _sym_db.RegisterMessage(GetGlobalModelRequest) GetGlobalModelResponse = _reflection.GeneratedProtocolMessageType('GetGlobalModelResponse', (_message.Message,), { - 'DESCRIPTOR': _GETGLOBALMODELRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelResponse) -}) + 'DESCRIPTOR' : _GETGLOBALMODELRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.GetGlobalModelResponse) + }) _sym_db.RegisterMessage(GetGlobalModelResponse) Heartbeat = _reflection.GeneratedProtocolMessageType('Heartbeat', 
(_message.Message,), { - 'DESCRIPTOR': _HEARTBEAT, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Heartbeat) -}) + 'DESCRIPTOR' : _HEARTBEAT, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Heartbeat) + }) _sym_db.RegisterMessage(Heartbeat) ClientAvailableMessage = _reflection.GeneratedProtocolMessageType('ClientAvailableMessage', (_message.Message,), { - 'DESCRIPTOR': _CLIENTAVAILABLEMESSAGE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ClientAvailableMessage) -}) + 'DESCRIPTOR' : _CLIENTAVAILABLEMESSAGE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ClientAvailableMessage) + }) _sym_db.RegisterMessage(ClientAvailableMessage) ListClientsRequest = _reflection.GeneratedProtocolMessageType('ListClientsRequest', (_message.Message,), { - 'DESCRIPTOR': _LISTCLIENTSREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ListClientsRequest) -}) + 'DESCRIPTOR' : _LISTCLIENTSREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ListClientsRequest) + }) _sym_db.RegisterMessage(ListClientsRequest) ClientList = _reflection.GeneratedProtocolMessageType('ClientList', (_message.Message,), { - 'DESCRIPTOR': _CLIENTLIST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ClientList) -}) + 'DESCRIPTOR' : _CLIENTLIST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ClientList) + }) _sym_db.RegisterMessage(ClientList) Client = _reflection.GeneratedProtocolMessageType('Client', (_message.Message,), { - 'DESCRIPTOR': _CLIENT, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Client) -}) + 'DESCRIPTOR' : _CLIENT, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # 
@@protoc_insertion_point(class_scope:grpc.Client) + }) _sym_db.RegisterMessage(Client) ReassignRequest = _reflection.GeneratedProtocolMessageType('ReassignRequest', (_message.Message,), { - 'DESCRIPTOR': _REASSIGNREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReassignRequest) -}) + 'DESCRIPTOR' : _REASSIGNREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReassignRequest) + }) _sym_db.RegisterMessage(ReassignRequest) ReconnectRequest = _reflection.GeneratedProtocolMessageType('ReconnectRequest', (_message.Message,), { - 'DESCRIPTOR': _RECONNECTREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReconnectRequest) -}) + 'DESCRIPTOR' : _RECONNECTREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReconnectRequest) + }) _sym_db.RegisterMessage(ReconnectRequest) Parameter = _reflection.GeneratedProtocolMessageType('Parameter', (_message.Message,), { - 'DESCRIPTOR': _PARAMETER, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.Parameter) -}) + 'DESCRIPTOR' : _PARAMETER, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.Parameter) + }) _sym_db.RegisterMessage(Parameter) ControlRequest = _reflection.GeneratedProtocolMessageType('ControlRequest', (_message.Message,), { - 'DESCRIPTOR': _CONTROLREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ControlRequest) -}) + 'DESCRIPTOR' : _CONTROLREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ControlRequest) + }) _sym_db.RegisterMessage(ControlRequest) ControlResponse = _reflection.GeneratedProtocolMessageType('ControlResponse', (_message.Message,), { - 'DESCRIPTOR': _CONTROLRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' 
- # @@protoc_insertion_point(class_scope:grpc.ControlResponse) -}) + 'DESCRIPTOR' : _CONTROLRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ControlResponse) + }) _sym_db.RegisterMessage(ControlResponse) ReportResponse = _reflection.GeneratedProtocolMessageType('ReportResponse', (_message.Message,), { - 'DESCRIPTOR': _REPORTRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ReportResponse) -}) + 'DESCRIPTOR' : _REPORTRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ReportResponse) + }) _sym_db.RegisterMessage(ReportResponse) ConnectionRequest = _reflection.GeneratedProtocolMessageType('ConnectionRequest', (_message.Message,), { - 'DESCRIPTOR': _CONNECTIONREQUEST, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ConnectionRequest) -}) + 'DESCRIPTOR' : _CONNECTIONREQUEST, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ConnectionRequest) + }) _sym_db.RegisterMessage(ConnectionRequest) ConnectionResponse = _reflection.GeneratedProtocolMessageType('ConnectionResponse', (_message.Message,), { - 'DESCRIPTOR': _CONNECTIONRESPONSE, - '__module__': 'fedn.common.net.grpc.fedn_pb2' - # @@protoc_insertion_point(class_scope:grpc.ConnectionResponse) -}) + 'DESCRIPTOR' : _CONNECTIONRESPONSE, + '__module__' : 'fedn.common.net.grpc.fedn_pb2' + # @@protoc_insertion_point(class_scope:grpc.ConnectionResponse) + }) _sym_db.RegisterMessage(ConnectionResponse) _MODELSERVICE = DESCRIPTOR.services_by_name['ModelService'] @@ -249,77 +250,77 @@ _REDUCER = DESCRIPTOR.services_by_name['Reducer'] _CONNECTOR = DESCRIPTOR.services_by_name['Connector'] _COMBINER = DESCRIPTOR.services_by_name['Combiner'] -if _descriptor._USE_C_DESCRIPTORS is False: - - DESCRIPTOR._options = None - _STATUSTYPE._serialized_start = 2412 - _STATUSTYPE._serialized_end 
= 2544 - _CHANNEL._serialized_start = 2547 - _CHANNEL._serialized_end = 2681 - _MODELSTATUS._serialized_start = 2683 - _MODELSTATUS._serialized_end = 2753 - _ROLE._serialized_start = 2755 - _ROLE._serialized_end = 2811 - _COMMAND._serialized_start = 2813 - _COMMAND._serialized_end = 2887 - _CONNECTIONSTATUS._serialized_start = 2889 - _CONNECTIONSTATUS._serialized_end = 2962 - _RESPONSE._serialized_start = 41 - _RESPONSE._serialized_end = 99 - _STATUS._serialized_start = 102 - _STATUS._serialized_end = 370 - _STATUS_LOGLEVEL._serialized_start = 304 - _STATUS_LOGLEVEL._serialized_end = 370 - _MODELUPDATEREQUEST._serialized_start = 373 - _MODELUPDATEREQUEST._serialized_end = 544 - _MODELUPDATE._serialized_start = 547 - _MODELUPDATE._serialized_end = 722 - _MODELVALIDATIONREQUEST._serialized_start = 725 - _MODELVALIDATIONREQUEST._serialized_end = 922 - _MODELVALIDATION._serialized_start = 925 - _MODELVALIDATION._serialized_end = 1093 - _MODELREQUEST._serialized_start = 1096 - _MODELREQUEST._serialized_end = 1233 - _MODELRESPONSE._serialized_start = 1235 - _MODELRESPONSE._serialized_end = 1328 - _GETGLOBALMODELREQUEST._serialized_start = 1330 - _GETGLOBALMODELREQUEST._serialized_end = 1415 - _GETGLOBALMODELRESPONSE._serialized_start = 1417 - _GETGLOBALMODELRESPONSE._serialized_end = 1521 - _HEARTBEAT._serialized_start = 1523 - _HEARTBEAT._serialized_end = 1564 - _CLIENTAVAILABLEMESSAGE._serialized_start = 1566 - _CLIENTAVAILABLEMESSAGE._serialized_end = 1653 - _LISTCLIENTSREQUEST._serialized_start = 1655 - _LISTCLIENTSREQUEST._serialized_end = 1737 - _CLIENTLIST._serialized_start = 1739 - _CLIENTLIST._serialized_end = 1781 - _CLIENT._serialized_start = 1783 - _CLIENT._serialized_end = 1831 - _REASSIGNREQUEST._serialized_start = 1833 - _REASSIGNREQUEST._serialized_end = 1942 - _RECONNECTREQUEST._serialized_start = 1944 - _RECONNECTREQUEST._serialized_end = 2043 - _PARAMETER._serialized_start = 2045 - _PARAMETER._serialized_end = 2084 - _CONTROLREQUEST._serialized_start = 
2086 - _CONTROLREQUEST._serialized_end = 2170 - _CONTROLRESPONSE._serialized_start = 2172 - _CONTROLRESPONSE._serialized_end = 2242 - _REPORTRESPONSE._serialized_start = 2244 - _REPORTRESPONSE._serialized_end = 2326 - _CONNECTIONREQUEST._serialized_start = 2328 - _CONNECTIONREQUEST._serialized_end = 2347 - _CONNECTIONRESPONSE._serialized_start = 2349 - _CONNECTIONRESPONSE._serialized_end = 2409 - _MODELSERVICE._serialized_start = 2964 - _MODELSERVICE._serialized_end = 3086 - _CONTROL._serialized_start = 3089 - _CONTROL._serialized_end = 3316 - _REDUCER._serialized_start = 3318 - _REDUCER._serialized_end = 3404 - _CONNECTOR._serialized_start = 3407 - _CONNECTOR._serialized_end = 3834 - _COMBINER._serialized_start = 3837 - _COMBINER._serialized_end = 4439 +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _STATUSTYPE._serialized_start=2412 + _STATUSTYPE._serialized_end=2544 + _CHANNEL._serialized_start=2547 + _CHANNEL._serialized_end=2681 + _MODELSTATUS._serialized_start=2683 + _MODELSTATUS._serialized_end=2753 + _ROLE._serialized_start=2755 + _ROLE._serialized_end=2811 + _COMMAND._serialized_start=2813 + _COMMAND._serialized_end=2887 + _CONNECTIONSTATUS._serialized_start=2889 + _CONNECTIONSTATUS._serialized_end=2962 + _RESPONSE._serialized_start=41 + _RESPONSE._serialized_end=99 + _STATUS._serialized_start=102 + _STATUS._serialized_end=370 + _STATUS_LOGLEVEL._serialized_start=304 + _STATUS_LOGLEVEL._serialized_end=370 + _MODELUPDATEREQUEST._serialized_start=373 + _MODELUPDATEREQUEST._serialized_end=544 + _MODELUPDATE._serialized_start=547 + _MODELUPDATE._serialized_end=722 + _MODELVALIDATIONREQUEST._serialized_start=725 + _MODELVALIDATIONREQUEST._serialized_end=922 + _MODELVALIDATION._serialized_start=925 + _MODELVALIDATION._serialized_end=1093 + _MODELREQUEST._serialized_start=1096 + _MODELREQUEST._serialized_end=1233 + _MODELRESPONSE._serialized_start=1235 + _MODELRESPONSE._serialized_end=1328 + 
_GETGLOBALMODELREQUEST._serialized_start=1330 + _GETGLOBALMODELREQUEST._serialized_end=1415 + _GETGLOBALMODELRESPONSE._serialized_start=1417 + _GETGLOBALMODELRESPONSE._serialized_end=1521 + _HEARTBEAT._serialized_start=1523 + _HEARTBEAT._serialized_end=1564 + _CLIENTAVAILABLEMESSAGE._serialized_start=1566 + _CLIENTAVAILABLEMESSAGE._serialized_end=1653 + _LISTCLIENTSREQUEST._serialized_start=1655 + _LISTCLIENTSREQUEST._serialized_end=1737 + _CLIENTLIST._serialized_start=1739 + _CLIENTLIST._serialized_end=1781 + _CLIENT._serialized_start=1783 + _CLIENT._serialized_end=1831 + _REASSIGNREQUEST._serialized_start=1833 + _REASSIGNREQUEST._serialized_end=1942 + _RECONNECTREQUEST._serialized_start=1944 + _RECONNECTREQUEST._serialized_end=2043 + _PARAMETER._serialized_start=2045 + _PARAMETER._serialized_end=2084 + _CONTROLREQUEST._serialized_start=2086 + _CONTROLREQUEST._serialized_end=2170 + _CONTROLRESPONSE._serialized_start=2172 + _CONTROLRESPONSE._serialized_end=2242 + _REPORTRESPONSE._serialized_start=2244 + _REPORTRESPONSE._serialized_end=2326 + _CONNECTIONREQUEST._serialized_start=2328 + _CONNECTIONREQUEST._serialized_end=2347 + _CONNECTIONRESPONSE._serialized_start=2349 + _CONNECTIONRESPONSE._serialized_end=2409 + _MODELSERVICE._serialized_start=2964 + _MODELSERVICE._serialized_end=3086 + _CONTROL._serialized_start=3089 + _CONTROL._serialized_end=3386 + _REDUCER._serialized_start=3388 + _REDUCER._serialized_end=3474 + _CONNECTOR._serialized_start=3477 + _CONNECTOR._serialized_end=3904 + _COMBINER._serialized_start=3907 + _COMBINER._serialized_end=4509 # @@protoc_insertion_point(module_scope) diff --git a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py index 7cf017b34..9590e2b5c 100644 --- a/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py +++ b/fedn/fedn/common/net/grpc/fedn_pb2_grpc.py @@ -2,8 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from fedn.common.net.grpc import \ - 
fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 +from fedn.common.net.grpc import fedn_pb2 as fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2 class ModelServiceStub(object): @@ -16,15 +15,15 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Upload = channel.stream_unary( - '/grpc.ModelService/Upload', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - ) + '/grpc.ModelService/Upload', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + ) self.Download = channel.unary_stream( - '/grpc.ModelService/Download', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - ) + '/grpc.ModelService/Download', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + ) class ModelServiceServicer(object): @@ -45,60 +44,59 @@ def Download(self, request, context): def add_ModelServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'Upload': grpc.stream_unary_rpc_method_handler( - servicer.Upload, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, - ), - 'Download': grpc.unary_stream_rpc_method_handler( - servicer.Download, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, - ), + 
'Upload': grpc.stream_unary_rpc_method_handler( + servicer.Upload, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + ), + 'Download': grpc.unary_stream_rpc_method_handler( + servicer.Download, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.ModelService', rpc_method_handlers) + 'grpc.ModelService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class ModelService(object): """Missing associated documentation comment in .proto file.""" @staticmethod def Upload(request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.stream_unary(request_iterator, target, '/grpc.ModelService/Upload', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Download(request, - target, - 
options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.ModelService/Download', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ControlStub(object): @@ -111,25 +109,30 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Start = channel.unary_unary( - '/grpc.Control/Start', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - ) + '/grpc.Control/Start', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Stop = channel.unary_unary( - '/grpc.Control/Stop', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - ) + '/grpc.Control/Stop', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Configure = channel.unary_unary( - '/grpc.Control/Configure', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - ) + '/grpc.Control/Configure', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + ) + self.FlushAggregationQueue = channel.unary_unary( + '/grpc.Control/FlushAggregationQueue', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + ) self.Report = channel.unary_unary( - '/grpc.Control/Report', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - 
response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - ) + '/grpc.Control/Report', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + ) class ControlServicer(object): @@ -153,6 +156,12 @@ def Configure(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def FlushAggregationQueue(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Report(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -162,104 +171,125 @@ def Report(self, request, context): def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { - 'Start': grpc.unary_unary_rpc_method_handler( - servicer.Start, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, - ), - 'Stop': grpc.unary_unary_rpc_method_handler( - servicer.Stop, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, - ), - 'Configure': grpc.unary_unary_rpc_method_handler( - servicer.Configure, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, - ), - 'Report': grpc.unary_unary_rpc_method_handler( - servicer.Report, - 
request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, - ), + 'Start': grpc.unary_unary_rpc_method_handler( + servicer.Start, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Stop': grpc.unary_unary_rpc_method_handler( + servicer.Stop, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Configure': grpc.unary_unary_rpc_method_handler( + servicer.Configure, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, + ), + 'FlushAggregationQueue': grpc.unary_unary_rpc_method_handler( + servicer.FlushAggregationQueue, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), + 'Report': grpc.unary_unary_rpc_method_handler( + servicer.Report, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Control', rpc_method_handlers) + 'grpc.Control', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. 
class Control(object): """Missing associated documentation comment in .proto file.""" @staticmethod def Start(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Start', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Stop(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Stop', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + 
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Configure(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Configure', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def FlushAggregationQueue(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/grpc.Control/FlushAggregationQueue', + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Report(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + 
wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Control/Report', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReportResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ReducerStub(object): @@ -272,10 +302,10 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.GetGlobalModel = channel.unary_unary( - '/grpc.Reducer/GetGlobalModel', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, - ) + '/grpc.Reducer/GetGlobalModel', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, + ) class ReducerServicer(object): @@ -290,38 +320,37 @@ def GetGlobalModel(self, request, context): def add_ReducerServicer_to_server(servicer, server): rpc_method_handlers = { - 'GetGlobalModel': grpc.unary_unary_rpc_method_handler( - servicer.GetGlobalModel, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString, - ), + 'GetGlobalModel': grpc.unary_unary_rpc_method_handler( + servicer.GetGlobalModel, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.FromString, + 
response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Reducer', rpc_method_handlers) + 'grpc.Reducer', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Reducer(object): """Missing associated documentation comment in .proto file.""" @staticmethod def GetGlobalModel(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Reducer/GetGlobalModel', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.GetGlobalModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class ConnectorStub(object): @@ -334,40 +363,40 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.AllianceStatusStream = channel.unary_stream( - '/grpc.Connector/AllianceStatusStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - ) + '/grpc.Connector/AllianceStatusStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + ) self.SendStatus = channel.unary_unary( - '/grpc.Connector/SendStatus', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/SendStatus', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ListActiveClients = channel.unary_unary( - '/grpc.Connector/ListActiveClients', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, - ) + '/grpc.Connector/ListActiveClients', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, + ) self.AcceptingClients = channel.unary_unary( - '/grpc.Connector/AcceptingClients', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, - ) + '/grpc.Connector/AcceptingClients', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, + 
response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, + ) self.SendHeartbeat = channel.unary_unary( - '/grpc.Connector/SendHeartbeat', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/SendHeartbeat', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ReassignClient = channel.unary_unary( - '/grpc.Connector/ReassignClient', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/ReassignClient', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.ReconnectClient = channel.unary_unary( - '/grpc.Connector/ReconnectClient', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Connector/ReconnectClient', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) class ConnectorServicer(object): @@ -423,170 +452,169 @@ def ReconnectClient(self, request, context): def add_ConnectorServicer_to_server(servicer, server): rpc_method_handlers = { - 'AllianceStatusStream': grpc.unary_stream_rpc_method_handler( - servicer.AllianceStatusStream, - 
request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - ), - 'SendStatus': grpc.unary_unary_rpc_method_handler( - servicer.SendStatus, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ListActiveClients': grpc.unary_unary_rpc_method_handler( - servicer.ListActiveClients, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString, - ), - 'AcceptingClients': grpc.unary_unary_rpc_method_handler( - servicer.AcceptingClients, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString, - ), - 'SendHeartbeat': grpc.unary_unary_rpc_method_handler( - servicer.SendHeartbeat, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ReassignClient': grpc.unary_unary_rpc_method_handler( - servicer.ReassignClient, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'ReconnectClient': grpc.unary_unary_rpc_method_handler( - servicer.ReconnectClient, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), + 'AllianceStatusStream': grpc.unary_stream_rpc_method_handler( + 
servicer.AllianceStatusStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + ), + 'SendStatus': grpc.unary_unary_rpc_method_handler( + servicer.SendStatus, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ListActiveClients': grpc.unary_unary_rpc_method_handler( + servicer.ListActiveClients, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.SerializeToString, + ), + 'AcceptingClients': grpc.unary_unary_rpc_method_handler( + servicer.AcceptingClients, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.SerializeToString, + ), + 'SendHeartbeat': grpc.unary_unary_rpc_method_handler( + servicer.SendHeartbeat, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ReassignClient': grpc.unary_unary_rpc_method_handler( + servicer.ReassignClient, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'ReconnectClient': grpc.unary_unary_rpc_method_handler( + servicer.ReconnectClient, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), } generic_handler = 
grpc.method_handlers_generic_handler( - 'grpc.Connector', rpc_method_handlers) + 'grpc.Connector', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. - + # This class is part of an EXPERIMENTAL API. class Connector(object): """Missing associated documentation comment in .proto file.""" @staticmethod def AllianceStatusStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Connector/AllianceStatusStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendStatus(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/SendStatus', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - 
options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Status.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ListActiveClients(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ListActiveClients', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ListClientsRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientList.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def AcceptingClients(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/AcceptingClients', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, - 
fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ConnectionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendHeartbeat(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/SendHeartbeat', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Heartbeat.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ReassignClient(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ReassignClient', - 
fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReassignRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ReconnectClient(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Connector/ReconnectClient', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ReconnectRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) class CombinerStub(object): @@ -599,45 +627,45 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.ModelUpdateRequestStream = channel.unary_stream( - '/grpc.Combiner/ModelUpdateRequestStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - ) + '/grpc.Combiner/ModelUpdateRequestStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + ) self.ModelUpdateStream = channel.unary_stream( - '/grpc.Combiner/ModelUpdateStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - ) + '/grpc.Combiner/ModelUpdateStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + ) self.ModelValidationRequestStream = channel.unary_stream( - '/grpc.Combiner/ModelValidationRequestStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - ) + '/grpc.Combiner/ModelValidationRequestStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + ) self.ModelValidationStream = channel.unary_stream( - '/grpc.Combiner/ModelValidationStream', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - 
response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, - ) + '/grpc.Combiner/ModelValidationStream', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + ) self.SendModelUpdateRequest = channel.unary_unary( - '/grpc.Combiner/SendModelUpdateRequest', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelUpdateRequest', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelUpdate = channel.unary_unary( - '/grpc.Combiner/SendModelUpdate', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelUpdate', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelValidationRequest = channel.unary_unary( - '/grpc.Combiner/SendModelValidationRequest', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelValidationRequest', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) self.SendModelValidation = 
channel.unary_unary( - '/grpc.Combiner/SendModelValidation', - request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - ) + '/grpc.Combiner/SendModelValidation', + request_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + response_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + ) class CombinerServicer(object): @@ -695,189 +723,188 @@ def SendModelValidation(self, request, context): def add_CombinerServicer_to_server(servicer, server): rpc_method_handlers = { - 'ModelUpdateRequestStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelUpdateRequestStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - ), - 'ModelUpdateStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelUpdateStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - ), - 'ModelValidationRequestStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelValidationRequestStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - ), - 'ModelValidationStream': grpc.unary_stream_rpc_method_handler( - servicer.ModelValidationStream, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - ), - 'SendModelUpdateRequest': 
grpc.unary_unary_rpc_method_handler( - servicer.SendModelUpdateRequest, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelUpdate': grpc.unary_unary_rpc_method_handler( - servicer.SendModelUpdate, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelValidationRequest': grpc.unary_unary_rpc_method_handler( - servicer.SendModelValidationRequest, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), - 'SendModelValidation': grpc.unary_unary_rpc_method_handler( - servicer.SendModelValidation, - request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, - response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, - ), + 'ModelUpdateRequestStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelUpdateRequestStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + ), + 'ModelUpdateStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelUpdateStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + ), + 'ModelValidationRequestStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelValidationRequestStream, + 
request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + ), + 'ModelValidationStream': grpc.unary_stream_rpc_method_handler( + servicer.ModelValidationStream, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + ), + 'SendModelUpdateRequest': grpc.unary_unary_rpc_method_handler( + servicer.SendModelUpdateRequest, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelUpdate': grpc.unary_unary_rpc_method_handler( + servicer.SendModelUpdate, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelValidationRequest': grpc.unary_unary_rpc_method_handler( + servicer.SendModelValidationRequest, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), + 'SendModelValidation': grpc.unary_unary_rpc_method_handler( + servicer.SendModelValidation, + request_deserializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + response_serializer=fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'grpc.Combiner', rpc_method_handlers) + 'grpc.Combiner', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. 
- + # This class is part of an EXPERIMENTAL API. class Combiner(object): """Missing associated documentation comment in .proto file.""" @staticmethod def ModelUpdateRequestStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelUpdateRequestStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelUpdateStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelUpdateStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + 
fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelValidationRequestStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelValidationRequestStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ModelValidationStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_stream(request, target, '/grpc.Combiner/ModelValidationStream', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, 
- options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ClientAvailableMessage.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelUpdateRequest(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelUpdateRequest', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdateRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelUpdate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelUpdate', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, - 
fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelUpdate.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelValidationRequest(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelValidationRequest', - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidationRequest.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendModelValidation(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/grpc.Combiner/SendModelValidation', - 
fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, - fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.ModelValidation.SerializeToString, + fedn_dot_common_dot_net_dot_grpc_dot_fedn__pb2.Response.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/fedn/fedn/network/combiner/interfaces.py b/fedn/fedn/network/combiner/interfaces.py index 832c5fdfa..d70fc80d0 100644 --- a/fedn/fedn/network/combiner/interfaces.py +++ b/fedn/fedn/network/combiner/interfaces.py @@ -51,7 +51,7 @@ def get_channel(self): class CombinerInterface: - """ Interface for the Combiner (server). + """ Interface for the Combiner (aggregation server). Abstraction on top of the gRPC server servicer. """ @@ -220,6 +220,23 @@ def configure(self, config=None): else: raise + def flush_model_update_queue(self): + """ Reset the model update queue on the combiner. """ + + channel = Channel(self.address, self.port, + self.certificate).get_channel() + control = rpc.ControlStub(channel) + + request = fedn.ControlRequest() + + try: + control.FlushAggregationQueue(request) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + raise CombinerUnavailableError + else: + raise + def submit(self, config): """ Submit a compute plan to the combiner. 
diff --git a/fedn/fedn/network/combiner/round.py b/fedn/fedn/network/combiner/round.py index 6ae7a92cc..a19adc9db 100644 --- a/fedn/fedn/network/combiner/round.py +++ b/fedn/fedn/network/combiner/round.py @@ -118,7 +118,6 @@ def waitforit(self, config, buffer_size=100, polling_interval=0.1): """ time_window = float(config['round_timeout']) - # buffer_size = int(config['buffer_size']) tt = 0.0 while tt < time_window: @@ -150,8 +149,14 @@ def _training_round(self, config, clients): # Request model updates from all active clients. self.server.request_model_update(config, clients=clients) + # If buffer_size is -1 (default), the round terminates when/if all clients have completed. + if int(config['buffer_size']) == -1: + buffer_size = len(clients) + else: + buffer_size = int(config['buffer_size']) + # Wait / block until the round termination policy has been met. - self.waitforit(config, buffer_size=len(clients)) + self.waitforit(config, buffer_size=buffer_size) tic = time.time() model = None @@ -159,7 +164,6 @@ def _training_round(self, config, clients): try: helper = get_helper(config['helper_type']) - # print config delete_models_storage print("ROUNDCONTROL: Config delete_models_storage: {}".format(config['delete_models_storage']), flush=True) if config['delete_models_storage'] == 'True': delete_models = True diff --git a/fedn/fedn/network/combiner/server.py b/fedn/fedn/network/combiner/server.py index ff0134d6b..f5449a375 100644 --- a/fedn/fedn/network/combiner/server.py +++ b/fedn/fedn/network/combiner/server.py @@ -371,6 +371,19 @@ def __register_heartbeat(self, client): self.__join_client(client) self.clients[client.name]["lastseen"] = datetime.now() + def flush_model_update_queue(self): + """Clear the model update queue (aggregator). 
""" + + q = self.control.aggregator.model_updates + try: + with q.mutex: + q.queue.clear() + q.all_tasks_done.notify_all() + q.unfinished_tasks = 0 + return True + except Exception: + return False + ##################################################################################################################### # Control Service @@ -400,6 +413,9 @@ def Start(self, control: fedn.ControlRequest, context): return response + # RPCs related to remote configuration of the server, round controller, + # aggregator and their states. + def Configure(self, control: fedn.ControlRequest, context): """ Configure the Combiner. @@ -416,6 +432,29 @@ def Configure(self, control: fedn.ControlRequest, context): response = fedn.ControlResponse() return response + def FlushAggregationQueue(self, control: fedn.ControlRequest, context): + """ Flush the queue. + + :param control: the control request + :type control: :class:`fedn.common.net.grpc.fedn_pb2.ControlRequest` + :param context: the context (unused) + :type context: :class:`grpc._server._Context` + :return: the control response + :rtype: :class:`fedn.common.net.grpc.fedn_pb2.ControlResponse` + """ + + status = self.flush_model_update_queue() + + response = fedn.ControlResponse() + if status: + response.message = 'Success' + else: + response.message = 'Failed' + + return response + + ############################################################################## + def Stop(self, control: fedn.ControlRequest, context): """ TODO: Not yet implemented. @@ -494,25 +533,6 @@ def Report(self, control: fedn.ControlRequest, context): ##################################################################################################################### - def AllianceStatusStream(self, response, context): - """ A server stream RPC endpoint that emits status messages. 
- - :param response: the response - :type response: :class:`fedn.common.net.grpc.fedn_pb2.Response` - :param context: the context (unused) - :type context: :class:`grpc._server._Context`""" - status = fedn.Status( - status="Client {} connecting to AllianceStatusStream.".format(response.sender)) - status.log_level = fedn.Status.INFO - status.sender.name = self.id - status.sender.role = role_to_proto_role(self.role) - self._subscribe_client_to_queue(response.sender, fedn.Channel.STATUS) - q = self.__get_queue(response.sender, fedn.Channel.STATUS) - self._send_status(status) - - while True: - yield q.get() - def SendStatus(self, status: fedn.Status, context): """ A client stream RPC endpoint that accepts status messages. diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index 7929f2892..68e4623e1 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -77,18 +77,23 @@ def session(self, config): print("Controller already in INSTRUCTING state. A session is in progress.", flush=True) return + if not self.get_latest_model(): + print("No model in model chain, please provide a seed model!") + return + self._state = ReducerState.instructing # Must be called to set info in the db self.new_session(config) - if not self.get_latest_model(): - print("No model in model chain, please provide a seed model!") - self._state = ReducerState.monitoring last_round = int(self.get_latest_round_id()) + # Clear potential stragglers/old model updates at combiners + for combiner in self.network.get_combiners(): + combiner.flush_model_update_queue() + # Execute the rounds in this session for round in range(1, int(config['rounds'] + 1)): # Increment the round number @@ -179,7 +184,7 @@ def round(self, session_config, round_id): else: # Print every 10 seconds based on value of wait if wait % 10 == 0: - print("CONTROL: Round not found! 
Waiting...", flush=True) + print("CONTROL: Waiting for round to complete...", flush=True) if wait >= session_config['round_timeout']: print("CONTROL: Round timeout! Exiting round...", flush=True) break diff --git a/fedn/fedn/network/dashboard/restservice.py b/fedn/fedn/network/dashboard/restservice.py index b951fb83d..5782ff088 100644 --- a/fedn/fedn/network/dashboard/restservice.py +++ b/fedn/fedn/network/dashboard/restservice.py @@ -571,6 +571,7 @@ def control(): if request.method == 'POST': # Get session configuration round_timeout = float(request.form.get('timeout', 180)) + buffer_size = int(request.form.get('buffer_size', -1)) rounds = int(request.form.get('rounds', 1)) delete_models = request.form.get('delete_models', True) task = (request.form.get('task', '')) @@ -601,8 +602,8 @@ def control(): latest_model_id = self.control.get_latest_model() - config = {'round_timeout': round_timeout, 'model_id': latest_model_id, - 'rounds': rounds, 'delete_models_storage': delete_models, + config = {'round_timeout': round_timeout, 'buffer_size': buffer_size, + 'model_id': latest_model_id, 'rounds': rounds, 'delete_models_storage': delete_models, 'clients_required': clients_required, 'clients_requested': clients_requested, 'task': task, 'validate': validate, 'helper_type': helper_type} diff --git a/fedn/fedn/network/dashboard/templates/index.html b/fedn/fedn/network/dashboard/templates/index.html index a4182130d..4ac2d182b 100644 --- a/fedn/fedn/network/dashboard/templates/index.html +++ b/fedn/fedn/network/dashboard/templates/index.html @@ -305,6 +305,16 @@
New session
+ + +
+ +
+ +
From 6d0b7ff11836defd36d1c59a20bd643b77648465 Mon Sep 17 00:00:00 2001 From: Fredrik Wrede Date: Wed, 25 Oct 2023 16:22:00 +0200 Subject: [PATCH 4/8] update version --- README.rst | 2 +- fedn/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 42503985a..2afb60ebc 100644 --- a/README.rst +++ b/README.rst @@ -101,7 +101,7 @@ To connect a client that uses the data partition 'data/clients/1/mnist.pt': -v $PWD/data/clients/1:/var/data \ -e ENTRYPOINT_OPTS=--data_path=/var/data/mnist.pt \ --network=fedn_default \ - ghcr.io/scaleoutsystems/fedn/fedn:develop-mnist-pytorch run client -in client.yaml --name client1 + ghcr.io/scaleoutsystems/fedn/fedn:master-mnist-pytorch run client -in client.yaml --name client1 You are now ready to start training the model at http://localhost:8090/control. diff --git a/fedn/setup.py b/fedn/setup.py index b6346388f..93eea674d 100644 --- a/fedn/setup.py +++ b/fedn/setup.py @@ -2,7 +2,7 @@ setup( name='fedn', - version='0.5.0-dev', + version='0.5.0', description="""Scaleout Federated Learning""", long_description=open('README.md').read(), long_description_content_type="text/markdown", From 1cd03e359b5f67ad7ac647a3d2cd91b70acee7e3 Mon Sep 17 00:00:00 2001 From: Niklas Date: Wed, 1 Nov 2023 12:29:33 +0100 Subject: [PATCH 5/8] Feature/SK-553 | Add pagination option to REST-API (#482) Co-authored-by: Fredrik Wrede --- .ci/tests/examples/wait_for.py | 32 +- fedn/fedn/common/tracer/mongotracer.py | 16 + fedn/fedn/network/api/interface.py | 635 ++++++++----- fedn/fedn/network/api/network.py | 2 +- fedn/fedn/network/api/server.py | 221 +++-- fedn/fedn/network/combiner/server.py | 6 + fedn/fedn/network/controller/control.py | 237 +++-- fedn/fedn/network/controller/controlbase.py | 124 ++- fedn/fedn/network/dashboard/restservice.py | 855 +++++++++++------- .../network/dashboard/templates/events.html | 65 +- .../network/statestore/mongostatestore.py | 490 +++++++--- 11 files changed, 1746 
insertions(+), 937 deletions(-) diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 7fa75506d..dc3345da0 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -37,21 +37,31 @@ def _test_rounds(n_rounds): return n == n_rounds -def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8090'): +def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092'): try: - resp = requests.get( - f'http://{reducer_host}:{reducer_port}/netgraph', verify=False) + + endpoint = "list_clients" if node_type == "client" else "list_combiners" + + response = requests.get( + f'http://{reducer_host}:{reducer_port}/{endpoint}', verify=False) + + if response.status_code == 200: + + data = json.loads(response.content) + + count = 0 + if node_type == "client": + arr = data.get('result') + count = sum(element.get('status') == "online" for element in arr) + else: + count = data.get('count') + + _eprint(f'Active {node_type}s: {count}.') + return count == n_nodes + except Exception as e: _eprint(f'Reques exception econuntered: {e}.') return False - if resp.status_code == 200: - gr = json.loads(resp.content) - n = sum(values.get('type') == node_type and values.get( - 'status') == 'active' for values in gr['nodes']) - _eprint(f'Active {node_type}s: {n}.') - return n == n_nodes - _eprint(f'Reducer returned {resp.status_code}.') - return False def rounds(n_rounds=3): diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 92af569ea..0a3e28cdc 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -1,4 +1,5 @@ import uuid +from datetime import datetime from google.protobuf.json_format import MessageToDict @@ -18,6 +19,7 @@ def __init__(self, mongo_config, network_id): self.rounds = self.mdb['control.rounds'] self.sessions = self.mdb['control.sessions'] self.validations = self.mdb['control.validations'] + self.clients 
= self.mdb['network.clients'] except Exception as e: print("FAILED TO CONNECT TO MONGO, {}".format(e), flush=True) self.status = None @@ -82,3 +84,17 @@ def set_round_data(self, round_data): """ self.rounds.update_one({'round_id': str(round_data['round_id'])}, { '$push': {'reducer': round_data}}, True) + + def update_client_status(self, client_name, status): + """ Update client status in statestore. + :param client_name: The client name + :type client_name: str + :param status: The client status + :type status: str + :return: None + """ + datetime_now = datetime.now() + filter_query = {"name": client_name} + + update_query = {"$set": {"last_seen": datetime_now, "status": status}} + self.clients.update_one(filter_query, update_query) diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index f05222e87..3a20e3502 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -1,44 +1,41 @@ import base64 import copy -import json import os import threading from io import BytesIO -from bson import json_util from flask import jsonify, send_from_directory from werkzeug.utils import secure_filename from fedn.common.config import get_controller_config, get_network_config from fedn.network.combiner.interfaces import (CombinerInterface, CombinerUnavailableError) +from fedn.network.dashboard.plots import Plot from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha -__all__ = 'API', +__all__ = ("API",) class API: - """ The API class is a wrapper for the statestore. It is used to expose the statestore to the network API. """ + """The API class is a wrapper for the statestore. It is used to expose the statestore to the network API.""" def __init__(self, statestore, control): self.statestore = statestore self.control = control - self.name = 'api' + self.name = "api" def _to_dict(self): - """ Convert the object to a dict. + """Convert the object to a dict. 
::return: The object as a dict. ::rtype: dict """ - data = { - 'name': self.name - } + data = {"name": self.name} return data def _get_combiner_report(self, combiner_id): - """ Get report response from combiner. + """Get report response from combiner. :param combiner_id: The combiner id to get report response from. :type combiner_id: str @@ -50,8 +47,10 @@ def _get_combiner_report(self, combiner_id): report = combiner.report return report - def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={'gz', 'bz2', 'tar', 'zip', 'tgz'}): - """ Check if file extension is allowed. + def _allowed_file_extension( + self, filename, ALLOWED_EXTENSIONS={"gz", "bz2", "tar", "zip", "tgz"} + ): + """Check if file extension is allowed. :param filename: The filename to check. :type filename: str @@ -59,32 +58,40 @@ def _allowed_file_extension(self, filename, ALLOWED_EXTENSIONS={'gz', 'bz2', 'ta :rtype: bool """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return ( + "." in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) - def get_all_clients(self): - """ Get all clients from the statestore. + def get_clients(self, limit=None, skip=None, status=False): + """Get all clients from the statestore. :return: All clients as a json response. 
:rtype: :class:`flask.Response` """ # Will return list of ObjectId - clients_objects = self.statestore.list_clients() - payload = {} - for object in clients_objects: - id = object['name'] - info = {"combiner": object['combiner'], - "combiner_preferred": object['combiner_preferred'], - "ip": object['ip'], - "updated_at": object['updated_at'], - "status": object['status'], - } - payload[id] = info + response = self.statestore.list_clients(limit, skip, status) + + arr = [] + + for element in response["result"]: + obj = { + "id": element["name"], + "combiner": element["combiner"], + "combiner_preferred": element["combiner_preferred"], + "ip": element["ip"], + "status": element["status"], + "last_seen": element["last_seen"], + } - return jsonify(payload) + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) def get_active_clients(self, combiner_id): - """ Get all active clients, i.e that are assigned to a combiner. + """Get all active clients, i.e that are assigned to a combiner. A report request to the combiner is neccessary to determine if a client is active or not. :param combiner_id: The combiner id to get active clients for. @@ -95,34 +102,42 @@ def get_active_clients(self, combiner_id): # Get combiner interface object combiner = self.control.network.get_combiner(combiner_id) if combiner is None: - return jsonify({"success": False, "message": f"Combiner {combiner_id} not found."}), 404 + return ( + jsonify( + { + "success": False, + "message": f"Combiner {combiner_id} not found.", + } + ), + 404, + ) response = combiner.list_active_clients() return response - def get_all_combiners(self): - """ Get all combiners from the statestore. + def get_all_combiners(self, limit=None, skip=None): + """Get all combiners from the statestore. :return: All combiners as a json response. 
:rtype: :class:`flask.Response` """ # Will return list of ObjectId - combiner_objects = self.statestore.get_combiners() - payload = {} - for object in combiner_objects: - id = object['name'] - info = {"address": object['address'], - "fqdn": object['fqdn'], - "parent_reducer": object['parent']["name"], - "port": object['port'], - "report": object['report'], - "updated_at": object['updated_at'], - } - payload[id] = info + projection = {"name": True, "updated_at": True} + response = self.statestore.get_combiners(limit, skip, projection=projection) + arr = [] + for element in response["result"]: + obj = { + "name": element["name"], + "updated_at": element["updated_at"], + } - return jsonify(payload) + arr.append(obj) + + result = {"result": arr, "count": response["count"]} + + return jsonify(result) def get_combiner(self, combiner_id): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. :param combiner_id: The combiner id to get. :type combiner_id: str @@ -132,36 +147,42 @@ def get_combiner(self, combiner_id): # Will return ObjectId object = self.statestore.get_combiner(combiner_id) payload = {} - id = object['name'] - info = {"address": object['address'], - "fqdn": object['fqdn'], - "parent_reducer": object['parent']["name"], - "port": object['port'], - "report": object['report'], - "updated_at": object['updated_at'], - } + id = object["name"] + info = { + "address": object["address"], + "fqdn": object["fqdn"], + "parent_reducer": object["parent"]["name"], + "port": object["port"], + "report": object["report"], + "updated_at": object["updated_at"], + } payload[id] = info return jsonify(payload) - def get_all_sessions(self): - """ Get all sessions from the statestore. + def get_all_sessions(self, limit=None, skip=None): + """Get all sessions from the statestore. :return: All sessions as a json response. 
:rtype: :class:`flask.Response` """ - sessions_objects = self.statestore.get_sessions() - if sessions_objects is None: - return jsonify({"success": False, "message": "No sessions found."}), 404 - payload = {} - for object in sessions_objects: - id = object['session_id'] - info = object['session_config'][0] - payload[id] = info - return jsonify(payload) + sessions_object = self.statestore.get_sessions(limit, skip) + if sessions_object is None: + return ( + jsonify({"success": False, "message": "No sessions found."}), + 404, + ) + arr = [] + for element in sessions_object["result"]: + obj = element["session_config"][0] + arr.append(obj) + + result = {"result": arr, "count": sessions_object["count"]} + + return jsonify(result) def get_session(self, session_id): - """ Get a session from the statestore. + """Get a session from the statestore. :param session_id: The session id to get. :type session_id: str @@ -170,15 +191,23 @@ def get_session(self, session_id): """ session_object = self.statestore.get_session(session_id) if session_object is None: - return jsonify({"success": False, "message": f"Session {session_id} not found."}), 404 + return ( + jsonify( + { + "success": False, + "message": f"Session {session_id} not found.", + } + ), + 404, + ) payload = {} - id = session_object['session_id'] - info = session_object['session_config'][0] + id = session_object["session_id"] + info = session_object["session_config"][0] payload[id] = info return jsonify(payload) def set_compute_package(self, file, helper_type): - """ Set the compute package in the statestore. + """Set the compute package in the statestore. :param file: The compute package to set. 
:type file: file @@ -189,24 +218,42 @@ def set_compute_package(self, file, helper_type): if file and self._allowed_file_extension(file.filename): filename = secure_filename(file.filename) # TODO: make configurable, perhaps in config.py or package.py - file_path = os.path.join( - '/app/client/package/', filename) + file_path = os.path.join("/app/client/package/", filename) file.save(file_path) - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: - return jsonify({"success": False, "message": "Reducer is in instructing or monitoring state." - "Cannot set compute package."}), 400 + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): + return ( + jsonify( + { + "success": False, + "message": "Reducer is in instructing or monitoring state." + "Cannot set compute package.", + } + ), + 400, + ) self.control.set_compute_package(filename, file_path) self.statestore.set_helper(helper_type) success = self.statestore.set_compute_package(filename) if not success: - return jsonify({"success": False, "message": "Failed to set compute package."}), 400 + return ( + jsonify( + { + "success": False, + "message": "Failed to set compute package.", + } + ), + 400, + ) return jsonify({"success": True, "message": "Compute package set."}) def _get_compute_package_name(self): - """ Get the compute package name from the statestore. + """Get the compute package name from the statestore. :return: The compute package name. :rtype: str @@ -217,32 +264,38 @@ def _get_compute_package_name(self): return None, message else: try: - name = package_objects['filename'] + name = package_objects["filename"] except KeyError as e: message = "No compute package found. Key error." print(e) return None, message - return name, 'success' + return name, "success" def get_compute_package(self): - """ Get the compute package from the statestore. + """Get the compute package from the statestore. 
:return: The compute package as a json response. :rtype: :class:`flask.Response` """ package_object = self.statestore.get_compute_package() if package_object is None: - return jsonify({"success": False, "message": "No compute package found."}), 404 + return ( + jsonify( + {"success": False, "message": "No compute package found."} + ), + 404, + ) payload = {} - id = str(package_object['_id']) - info = {"filename": package_object['filename'], - "helper": package_object['helper'], - } + id = str(package_object["_id"]) + info = { + "filename": package_object["filename"], + "helper": package_object["helper"], + } payload[id] = info return jsonify(payload) def download_compute_package(self, name): - """ Download the compute package. + """Download the compute package. :return: The compute package as a json object. :rtype: :class:`flask.Response` @@ -255,23 +308,27 @@ def download_compute_package(self, name): mutex = threading.Lock() mutex.acquire() # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory('/app/client/package/', name, as_attachment=True) + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) except Exception: try: data = self.control.get_compute_package(name) # TODO: make configurable, perhaps in config.py or package.py - file_path = os.path.join('/app/client/package/', name) - with open(file_path, 'wb') as fh: + file_path = os.path.join("/app/client/package/", name) + with open(file_path, "wb") as fh: fh.write(data) # TODO: make configurable, perhaps in config.py or package.py - return send_from_directory('/app/client/package/', name, as_attachment=True) + return send_from_directory( + "/app/client/package/", name, as_attachment=True + ) except Exception: raise finally: mutex.release() def _create_checksum(self, name=None): - """ Create the checksum of the compute package. + """Create the checksum of the compute package. :param name: The name of the compute package. 
:type name: str @@ -282,17 +339,19 @@ def _create_checksum(self, name=None): if name is None: name, message = self._get_compute_package_name() if name is None: - return False, message, '' - file_path = os.path.join('/app/client/package/', name) # TODO: make configurable, perhaps in config.py or package.py + return False, message, "" + file_path = os.path.join( + "/app/client/package/", name + ) # TODO: make configurable, perhaps in config.py or package.py try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' - message = 'File not found.' + sum = "" + message = "File not found." return True, message, sum def get_checksum(self, name): - """ Get the checksum of the compute package. + """Get the checksum of the compute package. :param name: The name of the compute package. :type name: str @@ -303,66 +362,75 @@ def get_checksum(self, name): success, message, sum = self._create_checksum(name) if not success: return jsonify({"success": False, "message": message}), 404 - payload = {'checksum': sum} + payload = {"checksum": sum} return jsonify(payload) def get_controller_status(self): - """ Get the status of the controller. + """Get the status of the controller. :return: The status of the controller as a json object. :rtype: :py:class:`flask.Response` """ - return jsonify({'state': ReducerStateToString(self.control.state())}) + return jsonify({"state": ReducerStateToString(self.control.state())}) def get_events(self, **kwargs): - """ Get the events of the federated network. + """Get the events of the federated network. :return: The events as a json object. 
:rtype: :py:class:`flask.Response` """ - event_objects = self.statestore.get_events(**kwargs) - if event_objects is None: - return jsonify({"success": False, "message": "No events found."}), 404 - json_docs = [] - for doc in self.statestore.get_events(**kwargs): - json_doc = json.dumps(doc, default=json_util.default) - json_docs.append(json_doc) + response = self.statestore.get_events(**kwargs) - json_docs.reverse() - return jsonify({'events': json_docs}) + result = response["result"] + if result is None: + return ( + jsonify({"success": False, "message": "No events found."}), + 404, + ) + + events = [] + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) def get_all_validations(self, **kwargs): - """ Get all validations from the statestore. + """Get all validations from the statestore. :return: All validations as a json response. :rtype: :class:`flask.Response` """ validations_objects = self.statestore.get_validations(**kwargs) if validations_objects is None: - return jsonify( - { - "success": False, - "message": "No validations found.", - "filter_used": kwargs - } - ), 404 + return ( + jsonify( + { + "success": False, + "message": "No validations found.", + "filter_used": kwargs, + } + ), + 404, + ) payload = {} for object in validations_objects: - id = str(object['_id']) + id = str(object["_id"]) info = { - 'model_id': object['modelId'], - 'data': object['data'], - 'timestamp': object['timestamp'], - 'meta': object['meta'], - 'sender': object['sender'], - 'receiver': object['receiver'], + "model_id": object["modelId"], + "data": object["data"], + "timestamp": object["timestamp"], + "meta": object["meta"], + "sender": object["sender"], + "receiver": object["receiver"], } payload[id] = info return jsonify(payload) - def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, port): - """ Add a combiner to the network. 
+ def add_combiner( + self, combiner_id, secure_grpc, address, remote_addr, fqdn, port + ): + """Add a combiner to the network. :param combiner_id: The combiner id to add. :type combiner_id: str @@ -383,18 +451,20 @@ def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, por """ # TODO: Any more required check for config? Formerly based on status: "retry" if not self.control.idle(): - return jsonify({ - 'success': False, - 'status': 'retry', - 'message': 'Conroller is not in idle state, try again later. ' - } + return jsonify( + { + "success": False, + "status": "retry", + "message": "Conroller is not in idle state, try again later. ", + } ) # Check if combiner already exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - if secure_grpc == 'True': + if secure_grpc == "True": certificate, key = self.certificate_manager.get_or_create( - address).get_keypair_raw() + address + ).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -410,29 +480,32 @@ def add_combiner(self, combiner_id, secure_grpc, address, remote_addr, fqdn, por port=port, certificate=copy.deepcopy(certificate), key=copy.deepcopy(key), - ip=remote_addr) + ip=remote_addr, + ) self.control.network.add_combiner(combiner_interface) # Check combiner now exists combiner = self.control.network.get_combiner(combiner_id) if not combiner: - return jsonify({'success': False, 'message': 'Combiner not added.'}) + return jsonify( + {"success": False, "message": "Combiner not added."} + ) payload = { - 'success': True, - 'message': 'Combiner added successfully.', - 'status': 'added', - 'storage': self.statestore.get_storage_backend(), - 'statestore': self.statestore.get_config(), - 'certificate': combiner.get_certificate(), - 'key': combiner.get_key() + "success": True, + "message": "Combiner added successfully.", + "status": "added", + "storage": self.statestore.get_storage_backend(), + "statestore": self.statestore.get_config(), + 
"certificate": combiner.get_certificate(), + "key": combiner.get_key(), } return jsonify(payload) def add_client(self, client_id, preferred_combiner, remote_addr): - """ Add a client to the network. + """Add a client to the network. :param client_id: The client id to add. :type client_id: str @@ -444,26 +517,46 @@ def add_client(self, client_id, preferred_combiner, remote_addr): # Check if package has been set package_object = self.statestore.get_compute_package() if package_object is None: - return jsonify({'success': False, 'status': 'retry', 'message': 'No compute package found. Set package in controller.'}), 203 + return ( + jsonify( + { + "success": False, + "status": "retry", + "message": "No compute package found. Set package in controller.", + } + ), + 203, + ) # Assign client to combiner if preferred_combiner: combiner = self.control.network.get_combiner(preferred_combiner) if combiner is None: - return jsonify({'success': False, - 'message': f'Combiner {preferred_combiner} not found or unavailable.'}), 400 + return ( + jsonify( + { + "success": False, + "message": f"Combiner {preferred_combiner} not found or unavailable.", + } + ), + 400, + ) else: combiner = self.control.network.find_available_combiner() if combiner is None: - return jsonify({'success': False, - 'message': 'No combiner available.'}), 400 + return ( + jsonify( + {"success": False, "message": "No combiner available."} + ), + 400, + ) client_config = { - 'name': client_id, - 'combiner_preferred': preferred_combiner, - 'combiner': combiner.name, - 'ip': remote_addr, - 'status': 'available' + "name": client_id, + "combiner_preferred": preferred_combiner, + "combiner": combiner.name, + "ip": remote_addr, + "status": "available", } # Add client to network self.control.network.add_client(client_config) @@ -471,38 +564,36 @@ def add_client(self, client_id, preferred_combiner, remote_addr): # Setup response containing information about the combiner for assinging the client if combiner.certificate: 
cert_b64 = base64.b64encode(combiner.certificate) - cert = str(cert_b64).split('\'')[1] + cert = str(cert_b64).split("'")[1] else: cert = None payload = { - 'status': 'assigned', - 'host': combiner.address, - 'fqdn': combiner.fqdn, - 'package': 'remote', # TODO: Make this configurable - 'ip': combiner.ip, - 'port': combiner.port, - 'certificate': cert, - 'helper_type': self.control.statestore.get_helper() + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": "remote", # TODO: Make this configurable + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "helper_type": self.control.statestore.get_helper(), } print("Seding payload: ", payload, flush=True) return jsonify(payload) def get_initial_model(self): - """ Get the initial model from the statestore. + """Get the initial model from the statestore. :return: The initial model as a json response. :rtype: :class:`flask.Response` """ model_id = self.statestore.get_initial_model() - payload = { - 'model_id': model_id - } + payload = {"model_id": model_id} return jsonify(payload) def set_initial_model(self, file): - """ Add an initial model to the network. + """Add an initial model to the network. :param file: The initial model to add. :type file: file @@ -520,27 +611,47 @@ def set_initial_model(self, file): self.control.commit(file.filename, model) except Exception as e: print(e, flush=True) - return jsonify({'success': False, 'message': e}) + return jsonify({"success": False, "message": e}) - return jsonify({'success': True, 'message': 'Initial model added successfully.'}) + return jsonify( + {"success": True, "message": "Initial model added successfully."} + ) def get_latest_model(self): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. :return: The initial model as a json response. 
:rtype: :class:`flask.Response` """ if self.statestore.get_latest_model(): model_id = self.statestore.get_latest_model() - payload = { - 'model_id': model_id - } + payload = {"model_id": model_id} return jsonify(payload) else: - return jsonify({'success': False, 'message': 'No initial model set.'}) + return jsonify( + {"success": False, "message": "No initial model set."} + ) + + def get_models(self, session_id=None, limit=None, skip=None): + result = self.statestore.list_models(session_id, limit, skip) + + if result is None: + return ( + jsonify({"success": False, "message": "No models found."}), + 404, + ) + + arr = [] + + for model in result["result"]: + arr.append(model) + + result = {"result": arr, "count": result["count"]} + + return jsonify(result) def get_model_trail(self): - """ Get the model trail for a given session. + """Get the model trail for a given session. :param session: The session id to get the model trail for. :type session: str @@ -551,38 +662,41 @@ def get_model_trail(self): if model_info: return jsonify(model_info) else: - return jsonify({'success': False, 'message': 'No model trail available.'}) + return jsonify( + {"success": False, "message": "No model trail available."} + ) def get_all_rounds(self): - """ Get all rounds. + """Get all rounds. :return: The rounds as json response. 
:rtype: :class:`flask.Response` """ rounds_objects = self.statestore.get_rounds() if rounds_objects is None: - jsonify({'success': False, 'message': 'No rounds available.'}) + jsonify({"success": False, "message": "No rounds available."}) payload = {} for object in rounds_objects: - id = object['round_id'] - if 'reducer' in object.keys(): - reducer = object['reducer'] + id = object["round_id"] + if "reducer" in object.keys(): + reducer = object["reducer"] else: reducer = None - if 'combiners' in object.keys(): - combiners = object['combiners'] + if "combiners" in object.keys(): + combiners = object["combiners"] else: combiners = None - info = {'reducer': reducer, - 'combiners': combiners, - } + info = { + "reducer": reducer, + "combiners": combiners, + } payload[id] = info else: return jsonify(payload) def get_round(self, round_id): - """ Get a round. + """Get a round. :param round_id: The round id to get. :type round_id: str @@ -591,38 +705,100 @@ def get_round(self, round_id): """ round_object = self.statestore.get_round(round_id) if round_object is None: - return jsonify({'success': False, 'message': 'Round not found.'}) + return jsonify({"success": False, "message": "Round not found."}) payload = { - 'round_id': round_object['round_id'], - 'reducer': round_object['reducer'], - 'combiners': round_object['combiners'], + "round_id": round_object["round_id"], + "reducer": round_object["reducer"], + "combiners": round_object["combiners"], } return jsonify(payload) def get_client_config(self, checksum=True): - """ Get the client config. + """Get the client config. :return: The client config as json response. 
:rtype: :py:class:`flask.Response` """ config = get_controller_config() network_id = get_network_config() - port = config['port'] - host = config['host'] + port = config["port"] + host = config["host"] payload = { - 'network_id': network_id, - 'discover_host': host, - 'discover_port': port, + "network_id": network_id, + "discover_host": host, + "discover_port": port, } if checksum: success, _, checksum_str = self._create_checksum() if success: - payload['checksum'] = checksum_str + payload["checksum"] = checksum_str return jsonify(payload) - def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_size=-1, delete_models=False, - validate=True, helper='keras', min_clients=1, requested_clients=8): - """ Start a session. + def get_plot_data(self, feature=None): + """Get plot data. + + :return: The plot data as json response. + :rtype: :py:class:`flask.Response` + """ + + plot = Plot(self.control.statestore) + + try: + valid_metrics = plot.fetch_valid_metrics() + feature = feature or valid_metrics[0] + box_plot = plot.create_box_plot(feature) + except Exception as e: + valid_metrics = None + box_plot = None + print(e, flush=True) + + result = { + "valid_metrics": valid_metrics, + "box_plot": box_plot, + } + + return jsonify(result) + + def list_combiners_data(self, combiners): + """Get combiners data. + + :param combiners: The combiners to get data for. + :type combiners: list + :return: The combiners data as json response. 
+ :rtype: :py:class:`flask.Response` + """ + + response = self.statestore.list_combiners_data(combiners) + + arr = [] + + # order list by combiner name + for element in response: + + obj = { + "combiner": element["_id"], + "count": element["count"], + } + + arr.append(obj) + + result = {"result": arr} + + return jsonify(result) + + def start_session( + self, + session_id, + rounds=5, + round_timeout=180, + round_buffer_size=-1, + delete_models=False, + validate=True, + helper="keras", + min_clients=1, + requested_clients=8, + ): + """Start a session. :param session_id: The session id to start. :type session_id: str @@ -646,18 +822,22 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si # Check if session already exists session = self.statestore.get_session(session_id) if session: - return jsonify({'success': False, 'message': 'Session already exists.'}) + return jsonify( + {"success": False, "message": "Session already exists."} + ) # Check if session is running if self.control.state() == ReducerState.monitoring: - return jsonify({'success': False, 'message': 'A session is already running.'}) + return jsonify( + {"success": False, "message": "A session is already running."} + ) # Check available clients per combiner clients_available = 0 for combiner in self.control.network.get_combiners(): try: combiner_state = combiner.report() - nr_active_clients = combiner_state['nr_active_clients'] + nr_active_clients = combiner_state["nr_active_clients"] clients_available = clients_available + int(nr_active_clients) except CombinerUnavailableError as e: # TODO: Handle unavailable combiner, stop session or continue? 
@@ -665,11 +845,16 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si continue if clients_available < min_clients: - return jsonify({'success': False, 'message': 'Not enough clients available to start session.'}) + return jsonify( + { + "success": False, + "message": "Not enough clients available to start session.", + } + ) # Check if validate is string and convert to bool if isinstance(validate, str): - if validate.lower() == 'true': + if validate.lower() == "true": validate = True else: validate = False @@ -678,22 +863,30 @@ def start_session(self, session_id, rounds=5, round_timeout=180, round_buffer_si model_id = self.statestore.get_latest_model() # Setup session config - session_config = {'session_id': session_id, - 'round_timeout': round_timeout, - 'buffer_size': round_buffer_size, - 'model_id': model_id, - 'rounds': rounds, - 'delete_models_storage': delete_models, - 'clients_required': min_clients, - 'clients_requested': requested_clients, - 'task': (''), - 'validate': validate, - 'helper_type': helper - } + session_config = { + "session_id": session_id, + "round_timeout": round_timeout, + "buffer_size": round_buffer_size, + "model_id": model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": min_clients, + "clients_requested": requested_clients, + "task": (""), + "validate": validate, + "helper_type": helper, + } # Start session - threading.Thread(target=self.control.session, - args=(session_config,)).start() + threading.Thread( + target=self.control.session, args=(session_config,) + ).start() # Return success response - return jsonify({'success': True, 'message': 'Session started successfully.', "config": session_config}) + return jsonify( + { + "success": True, + "message": "Session started successfully.", + "config": session_config, + } + ) diff --git a/fedn/fedn/network/api/network.py b/fedn/fedn/network/api/network.py index 26b366a76..6fcaad053 100644 --- 
a/fedn/fedn/network/api/network.py +++ b/fedn/fedn/network/api/network.py @@ -46,7 +46,7 @@ def get_combiners(self): """ data = self.statestore.get_combiners() combiners = [] - for c in data: + for c in data["result"]: if c['certificate']: cert = base64.b64decode(c['certificate']) key = base64.b64decode(c['key']) diff --git a/fedn/fedn/network/api/server.py b/fedn/fedn/network/api/server.py index 4e0e93775..cfb91bece 100644 --- a/fedn/fedn/network/api/server.py +++ b/fedn/fedn/network/api/server.py @@ -10,18 +10,16 @@ network_id = get_network_config() modelstorage_config = get_modelstorage_config() statestore = MongoStateStore( - network_id, - statestore_config['mongo_config'], - modelstorage_config + network_id, statestore_config["mongo_config"], modelstorage_config ) control = Control(statestore=statestore) api = API(statestore, control) app = Flask(__name__) -@app.route('/get_model_trail', methods=['GET']) +@app.route("/get_model_trail", methods=["GET"]) def get_model_trail(): - """ Get the model trail for a given session. + """Get the model trail for a given session. param: session: The session id to get the model trail for. type: session: str return: The model trail for the given session as a json object. @@ -30,9 +28,29 @@ def get_model_trail(): return api.get_model_trail() -@app.route('/delete_model_trail', methods=['GET', 'POST']) +@app.route("/list_models", methods=["GET"]) +def list_models(): + """Get models from the statestore. + param: + session_id: The session id to get the model trail for. + limit: The maximum number of models to return. + type: limit: int + param: skip: The number of models to skip. 
+ type: skip: int + Returns: + _type_: json + """ + + session_id = request.args.get("session_id", None) + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_models(session_id, limit, skip) + + +@app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): - """ Delete the model trail for a given session. + """Delete the model trail for a given session. param: session: The session id to delete the model trail for. type: session: str return: The response from the statestore. @@ -41,78 +59,93 @@ def delete_model_trail(): return jsonify({"message": "Not implemented"}), 501 -@app.route('/list_clients', methods=['GET']) +@app.route("/list_clients", methods=["GET"]) def list_clients(): - """ Get all clients from the statestore. + """Get all clients from the statestore. return: All clients as a json object. rtype: json """ - return api.get_all_clients() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + status = request.args.get("status", None) + + return api.get_clients(limit, skip, status) -@app.route('/get_active_clients', methods=['GET']) + +@app.route("/get_active_clients", methods=["GET"]) def get_active_clients(): - """ Get all active clients from the statestore. + """Get all active clients from the statestore. param: combiner_id: The combiner id to get active clients for. type: combiner_id: str return: All active clients as a json object. rtype: json """ - combiner_id = request.args.get('combiner', None) + combiner_id = request.args.get("combiner", None) if combiner_id is None: - return jsonify({"success": False, "message": "Missing combiner id."}), 400 + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) return api.get_active_clients(combiner_id) -@app.route('/list_combiners', methods=['GET']) +@app.route("/list_combiners", methods=["GET"]) def list_combiners(): - """ Get all combiners in the network. 
+ """Get all combiners in the network. return: All combiners as a json object. rtype: json """ - return api.get_all_combiners() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_combiners(limit, skip) -@app.route('/get_combiner', methods=['GET']) + +@app.route("/get_combiner", methods=["GET"]) def get_combiner(): - """ Get a combiner from the statestore. + """Get a combiner from the statestore. param: combiner_id: The combiner id to get. type: combiner_id: str return: The combiner as a json object. rtype: json """ - combiner_id = request.args.get('combiner', None) + combiner_id = request.args.get("combiner", None) if combiner_id is None: - return jsonify({"success": False, "message": "Missing combiner id."}), 400 + return ( + jsonify({"success": False, "message": "Missing combiner id."}), + 400, + ) return api.get_combiner(combiner_id) -@app.route('/list_rounds', methods=['GET']) +@app.route("/list_rounds", methods=["GET"]) def list_rounds(): - """ Get all rounds from the statestore. + """Get all rounds from the statestore. return: All rounds as a json object. rtype: json """ return api.get_all_rounds() -@app.route('/get_round', methods=['GET']) +@app.route("/get_round", methods=["GET"]) def get_round(): - """ Get a round from the statestore. + """Get a round from the statestore. param: round_id: The round id to get. type: round_id: str return: The round as a json object. rtype: json """ - round_id = request.args.get('round_id', None) + round_id = request.args.get("round_id", None) if round_id is None: return jsonify({"success": False, "message": "Missing round id."}), 400 return api.get_round(round_id) -@app.route('/start_session', methods=['GET', 'POST']) +@app.route("/start_session", methods=["GET", "POST"]) def start_session(): - """ Start a new session. + """Start a new session. return: The response from control. 
rtype: json """ @@ -120,30 +153,36 @@ def start_session(): return api.start_session(**json_data) -@app.route('/list_sessions', methods=['GET']) +@app.route("/list_sessions", methods=["GET"]) def list_sessions(): - """ Get all sessions from the statestore. + """Get all sessions from the statestore. return: All sessions as a json object. rtype: json """ - return api.get_all_sessions() + limit = request.args.get("limit", None) + skip = request.args.get("skip", None) + + return api.get_all_sessions(limit, skip) -@app.route('/get_session', methods=['GET']) +@app.route("/get_session", methods=["GET"]) def get_session(): - """ Get a session from the statestore. + """Get a session from the statestore. param: session_id: The session id to get. type: session_id: str return: The session as a json object. rtype: json """ - session_id = request.args.get('session_id', None) + session_id = request.args.get("session_id", None) if session_id is None: - return jsonify({"success": False, "message": "Missing session id."}), 400 + return ( + jsonify({"success": False, "message": "Missing session id."}), + 400, + ) return api.get_session(session_id) -@app.route('/set_package', methods=['POST']) +@app.route("/set_package", methods=["POST"]) def set_package(): """ Set the compute package in the statestore. Usage with curl: @@ -157,64 +196,68 @@ def set_package(): return: The response from the statestore. 
rtype: json """ - helper_type = request.form.get('helper', None) + helper_type = request.form.get("helper", None) if helper_type is None: - return jsonify({"success": False, "message": "Missing helper type."}), 400 + return ( + jsonify({"success": False, "message": "Missing helper type."}), + 400, + ) try: - file = request.files['file'] + file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_compute_package(file=file, helper_type=helper_type) -@app.route('/get_package', methods=['GET']) +@app.route("/get_package", methods=["GET"]) def get_package(): - """ Get the compute package from the statestore. + """Get the compute package from the statestore. return: The compute package as a json object. rtype: json """ return api.get_compute_package() -@app.route('/download_package', methods=['GET']) +@app.route("/download_package", methods=["GET"]) def download_package(): - """ Download the compute package. + """Download the compute package. return: The compute package as a json object. rtype: json """ - name = request.args.get('name', None) + name = request.args.get("name", None) return api.download_compute_package(name) -@app.route('/get_package_checksum', methods=['GET']) +@app.route("/get_package_checksum", methods=["GET"]) def get_package_checksum(): - name = request.args.get('name', None) + name = request.args.get("name", None) return api.get_checksum(name) -@app.route('/get_latest_model', methods=['GET']) +@app.route("/get_latest_model", methods=["GET"]) def get_latest_model(): - """ Get the latest model from the statestore. + """Get the latest model from the statestore. return: The initial model as a json object. rtype: json """ return api.get_latest_model() + # Get initial model endpoint -@app.route('/get_initial_model', methods=['GET']) +@app.route("/get_initial_model", methods=["GET"]) def get_initial_model(): - """ Get the initial model from the statestore. 
+ """Get the initial model from the statestore. return: The initial model as a json object. rtype: json """ return api.get_initial_model() -@app.route('/set_initial_model', methods=['POST']) +@app.route("/set_initial_model", methods=["POST"]) def set_initial_model(): - """ Set the initial model in the statestore and upload to model repository. + """Set the initial model in the statestore and upload to model repository. Usage with curl: curl -k -X POST -F file=@seed.npz @@ -226,45 +269,46 @@ def set_initial_model(): rtype: json """ try: - file = request.files['file'] + file = request.files["file"] except KeyError: return jsonify({"success": False, "message": "Missing file."}), 400 return api.set_initial_model(file) -@app.route('/get_controller_status', methods=['GET']) +@app.route("/get_controller_status", methods=["GET"]) def get_controller_status(): - """ Get the status of the controller. + """Get the status of the controller. return: The status as a json object. rtype: json """ return api.get_controller_status() -@app.route('/get_client_config', methods=['GET']) +@app.route("/get_client_config", methods=["GET"]) def get_client_config(): - """ Get the client configuration. + """Get the client configuration. return: The client configuration as a json object. rtype: json """ - checksum = request.args.get('checksum', True) + checksum = request.args.get("checksum", True) return api.get_client_config(checksum) -@app.route('/get_events', methods=['GET']) +@app.route("/get_events", methods=["GET"]) def get_events(): - """ Get the events from the statestore. + """Get the events from the statestore. return: The events as a json object. rtype: json """ # TODO: except filter with request.get_json() kwargs = request.args.to_dict() + return api.get_events(**kwargs) -@app.route('/list_validations', methods=['GET']) +@app.route("/list_validations", methods=["GET"]) def list_validations(): - """ Get all validations from the statestore. 
+ """Get all validations from the statestore. return: All validations as a json object. rtype: json """ @@ -273,9 +317,9 @@ def list_validations(): return api.get_all_validations(**kwargs) -@app.route('/add_combiner', methods=['POST']) +@app.route("/add_combiner", methods=["POST"]) def add_combiner(): - """ Add a combiner to the network. + """Add a combiner to the network. return: The response from the statestore. rtype: json """ @@ -284,13 +328,13 @@ def add_combiner(): try: response = api.add_combiner(**json_data, remote_addr=remote_addr) except TypeError as e: - return jsonify({'success': False, 'message': str(e)}), 400 + return jsonify({"success": False, "message": str(e)}), 400 return response -@app.route('/add_client', methods=['POST']) +@app.route("/add_client", methods=["POST"]) def add_client(): - """ Add a client to the network. + """Add a client to the network. return: The response from control. rtype: json """ @@ -300,12 +344,45 @@ def add_client(): try: response = api.add_client(**json_data, remote_addr=remote_addr) except TypeError as e: - return jsonify({'success': False, 'message': str(e)}), 400 + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/list_combiners_data", methods=["POST"]) +def list_combiners_data(): + """List data from combiners. + return: The response from control. + rtype: json + """ + + json_data = request.get_json() + + # expects a list of combiner names (strings) in an array + combiners = json_data.get("combiners", None) + + try: + response = api.list_combiners_data(combiners) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 + return response + + +@app.route("/get_plot_data", methods=["GET"]) +def get_plot_data(): + """Get plot data from the statestore. 
+ rtype: json + """ + + try: + feature = request.args.get("feature", None) + response = api.get_plot_data(feature=feature) + except TypeError as e: + return jsonify({"success": False, "message": str(e)}), 400 return response -if __name__ == '__main__': +if __name__ == "__main__": config = get_controller_config() - port = config['port'] - debug = config['debug'] - app.run(debug=debug, port=port, host='0.0.0.0') + port = config["port"] + debug = config["debug"] + app.run(debug=debug, port=port, host="0.0.0.0") diff --git a/fedn/fedn/network/combiner/server.py b/fedn/fedn/network/combiner/server.py index 11d874ea6..7a9c87ff9 100644 --- a/fedn/fedn/network/combiner/server.py +++ b/fedn/fedn/network/combiner/server.py @@ -98,9 +98,11 @@ def __init__(self, config): break if status == Status.UnAuthorized: print(response, flush=True) + print("Status.UnAuthorized", flush=True) sys.exit("Exiting: Unauthorized") if status == Status.UnMatchedConfig: print(response, flush=True) + print("Status.UnMatchedConfig", flush=True) sys.exit("Exiting: Missing config") cert = announce_config['certificate'] @@ -712,12 +714,16 @@ def ModelUpdateRequestStream(self, response, context): self._send_status(status) + self.tracer.update_client_status(client.name, "online") + while context.is_active(): try: yield q.get(timeout=1.0) except queue.Empty: pass + self.tracer.update_client_status(client.name, "offline") + def ModelValidationStream(self, update, context): """ Model validation stream RPC endpoint. Update status for client is connecting to stream. diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index 5f5bc6634..a8e32333d 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -9,10 +9,10 @@ class UnsupportedStorageBackend(Exception): - """ Exception class for when storage backend is not supported. Passes """ + """Exception class for when storage backend is not supported. 
Passes""" def __init__(self, message): - """ Constructor method. + """Constructor method. :param message: The exception message. :type message: str @@ -23,46 +23,46 @@ def __init__(self, message): class MisconfiguredStorageBackend(Exception): - """ Exception class for when storage backend is misconfigured. + """Exception class for when storage backend is misconfigured. :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" self.message = message super().__init__(self.message) class NoModelException(Exception): - """ Exception class for when model is None + """Exception class for when model is None :param message: The exception message. :type message: str """ def __init__(self, message): - """ Constructor method.""" + """Constructor method.""" self.message = message super().__init__(self.message) class Control(ControlBase): - """ Controller, implementing the overall global training, validation and inference logic. + """Controller, implementing the overall global training, validation and inference logic. :param statestore: A StateStorage instance. :type statestore: class: `fedn.network.statestorebase.StateStorageBase` """ def __init__(self, statestore): - """ Constructor method.""" + """Constructor method.""" super().__init__(statestore) self.name = "DefaultControl" def session(self, config): - """ Execute a new training session. A session consists of one + """Execute a new training session. A session consists of one or several global rounds. All rounds in the same session have the same round_config. @@ -72,7 +72,10 @@ def session(self, config): """ if self._state == ReducerState.instructing: - print("Controller already in INSTRUCTING state. A session is in progress.", flush=True) + print( + "Controller already in INSTRUCTING state. 
A session is in progress.", + flush=True, + ) return if not self.statestore.get_latest_model(): @@ -82,11 +85,16 @@ def session(self, config): self._state = ReducerState.instructing # Must be called to set info in the db - config['committed_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + config["committed_at"] = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) self.new_session(config) if not self.statestore.get_latest_model(): - print("No model in model chain, please provide a seed model!", flush=True) + print( + "No model in model chain, please provide a seed model!", + flush=True, + ) self._state = ReducerState.monitoring last_round = int(self.get_latest_round_id()) @@ -96,7 +104,7 @@ def session(self, config): combiner.flush_model_update_queue() # Execute the rounds in this session - for round in range(1, int(config['rounds'] + 1)): + for round in range(1, int(config["rounds"] + 1)): # Increment the round number if last_round: @@ -107,10 +115,17 @@ def session(self, config): try: _, round_data = self.round(config, current_round) except TypeError as e: - print("Could not unpack data from round: {0}".format(e), flush=True) - - print("CONTROL: Round completed with status {}".format( - round_data['status']), flush=True) + print( + "Could not unpack data from round: {0}".format(e), + flush=True, + ) + + print( + "CONTROL: Round completed with status {}".format( + round_data["status"] + ), + flush=True, + ) self.tracer.set_round_data(round_data) @@ -118,7 +133,7 @@ def session(self, config): self._state = ReducerState.idle def round(self, session_config, round_id): - """ Execute a single global round. + """Execute a single global round. :param session_config: The session config. 
:type session_config: dict @@ -126,35 +141,42 @@ def round(self, session_config, round_id): :type round_id: str(int) """ - round_data = {'round_id': round_id} + round_data = {"round_id": round_id} if len(self.network.get_combiners()) < 1: print("REDUCER: No combiners connected!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data # 1. Assemble round config for this global round, # and check which combiners are able to participate # in the round. round_config = copy.deepcopy(session_config) - round_config['rounds'] = 1 - round_config['round_id'] = round_id - round_config['task'] = 'training' - round_config['model_id'] = self.statestore.get_latest_model() - round_config['helper_type'] = self.statestore.get_helper() + round_config["rounds"] = 1 + round_config["round_id"] = round_id + round_config["task"] = "training" + round_config["model_id"] = self.statestore.get_latest_model() + round_config["helper_type"] = self.statestore.get_helper() combiners = self.get_participating_combiners(round_config) round_start = self.evaluate_round_start_policy(combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) + round_data["status"] = "Failed" return None - round_data['round_config'] = round_config + round_data["round_config"] = round_config # 2. 
Ask participating combiners to coordinate model updates _ = self.request_model_updates(combiners) @@ -164,27 +186,37 @@ def round(self, session_config, round_id): # dict to store combiners that have successfully produced an updated model updated = {} # wait until all combiners have produced an updated model or until round timeout - print("CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id), flush=True) + print( + "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( + round_id=round_id + ), + flush=True, + ) while len(updated) < len(combiners): round = self.statestore.get_round(round_id) if round: print("CONTROL: Round found!", flush=True) # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round['combiners']: + for combiner in round["combiners"]: print(combiner, flush=True) - if combiner['status'] == 'Success': - if combiner['name'] not in updated.keys(): + if combiner["status"] == "Success": + if combiner["name"] not in updated.keys(): # Add combiner to updated dict - updated[combiner['name']] = combiner['model_id'] + updated[combiner["name"]] = combiner["model_id"] # Print combiner status - print("CONTROL: Combiner {name} status: {status}".format( - name=combiner['name'], status=combiner['status']), flush=True) + print( + "CONTROL: Combiner {name} status: {status}".format( + name=combiner["name"], status=combiner["status"] + ), + flush=True, + ) else: # Print every 10 seconds based on value of wait if wait % 10 == 0: - print("CONTROL: Waiting for round to complete...", flush=True) - if wait >= session_config['round_timeout']: + print( + "CONTROL: Waiting for round to complete...", flush=True + ) + if wait >= session_config["round_timeout"]: print("CONTROL: Round timeout! 
Exiting round...", flush=True) break # Update wait time used for timeout @@ -194,53 +226,77 @@ def round(self, session_config, round_id): round_valid = self.evaluate_round_validity_policy(updated) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data['status'] = 'Failed' + round_data["status"] = "Failed" return None, round_data print("CONTROL: Reducing models from combiners...", flush=True) # 3. Reduce combiner models into a global model try: model, data = self.reduce(updated) - round_data['reduce'] = data + round_data["reduce"] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print("CONTROL: Failed to reduce models from combiners: {}".format( - e), flush=True) - round_data['status'] = 'Failed' + print( + "CONTROL: Failed to reduce models from combiners: {}".format( + e + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data # 6. Commit the global model to model trail if model is not None: - print("CONTROL: Committing global model to model trail...", flush=True) + print( + "CONTROL: Committing global model to model trail...", + flush=True, + ) tic = time.time() model_id = uuid.uuid4() - self.commit(model_id, model) - round_data['time_commit'] = time.time() - tic - print("CONTROL: Done committing global model to model trail!", flush=True) + session_id = ( + session_config["session_id"] + if "session_id" in session_config + else None + ) + self.commit(model_id, model, session_id) + round_data["time_commit"] = time.time() - tic + print( + "CONTROL: Done committing global model to model trail!", + flush=True, + ) else: - print("REDUCER: failed to update model in round with config {}".format( - session_config), flush=True) - round_data['status'] = 'Failed' + print( + "REDUCER: failed to update model in round with config {}".format( + session_config + ), + flush=True, + ) + round_data["status"] = "Failed" return None, round_data - round_data['status'] = 'Success' + 
round_data["status"] = "Success" # 4. Trigger participating combiner nodes to execute a validation round for the current model - validate = session_config['validate'] + validate = session_config["validate"] if validate: combiner_config = copy.deepcopy(session_config) - combiner_config['round_id'] = round_id - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'validation' - combiner_config['helper_type'] = self.statestore.get_helper() + combiner_config["round_id"] = round_id + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "validation" + combiner_config["helper_type"] = self.statestore.get_helper() validating_combiners = self._select_participating_combiners( - combiner_config) + combiner_config + ) for combiner, combiner_config in validating_combiners: try: - print("CONTROL: Submitting validation round to combiner {}".format( - combiner), flush=True) + print( + "CONTROL: Submitting validation round to combiner {}".format( + combiner + ), + flush=True, + ) combiner.submit(combiner_config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) @@ -249,16 +305,16 @@ def round(self, session_config, round_id): return model_id, round_data def reduce(self, combiners): - """ Combine updated models from Combiner nodes into one global model. + """Combine updated models from Combiner nodes into one global model. 
:param combiners: dict of combiner names (key) and model IDs (value) to reduce :type combiners: dict """ meta = {} - meta['time_fetch_model'] = 0.0 - meta['time_load_model'] = 0.0 - meta['time_aggregate_model'] = 0.0 + meta["time_fetch_model"] = 0.0 + meta["time_load_model"] = 0.0 + meta["time_aggregate_model"] = 0.0 i = 1 model = None @@ -268,18 +324,25 @@ def reduce(self, combiners): return model, meta for name, model_id in combiners.items(): - # TODO: Handle inactive RPC error in get_model and raise specific error - print("REDUCER: Fetching model ({model_id}) from combiner {name}".format( - model_id=model_id, name=name), flush=True) + print( + "REDUCER: Fetching model ({model_id}) from combiner {name}".format( + model_id=model_id, name=name + ), + flush=True, + ) try: tic = time.time() combiner = self.get_combiner(name) data = combiner.get_model(model_id) - meta['time_fetch_model'] += (time.time() - tic) + meta["time_fetch_model"] += time.time() - tic except Exception as e: - print("REDUCER: Failed to fetch model from combiner {}: {}".format( - name, e), flush=True) + print( + "REDUCER: Failed to fetch model from combiner {}: {}".format( + name, e + ), + flush=True, + ) data = None if data is not None: @@ -288,21 +351,21 @@ def reduce(self, combiners): helper = self.get_helper() data.seek(0) model_next = helper.load(data) - meta['time_load_model'] += (time.time() - tic) + meta["time_load_model"] += time.time() - tic tic = time.time() model = helper.increment_average(model, model_next, i, i) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic except Exception: tic = time.time() data.seek(0) model = helper.load(data) - meta['time_aggregate_model'] += (time.time() - tic) + meta["time_aggregate_model"] += time.time() - tic i = i + 1 return model, meta def infer_instruct(self, config): - """ Main entrypoint for executing the inference compute plan. + """Main entrypoint for executing the inference compute plan. 
:param config: configuration for the inference round """ @@ -330,7 +393,7 @@ def infer_instruct(self, config): self.__state = ReducerState.idle def inference_round(self, config): - """ Execute an inference round. + """Execute an inference round. :param config: configuration for the inference round """ @@ -345,21 +408,27 @@ def inference_round(self, config): # Setup combiner configuration combiner_config = copy.deepcopy(config) - combiner_config['model_id'] = self.statestore.get_latest_model() - combiner_config['task'] = 'inference' - combiner_config['helper_type'] = self.statestore.get_framework() + combiner_config["model_id"] = self.statestore.get_latest_model() + combiner_config["task"] = "inference" + combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners( - combiner_config) + validating_combiners = self._select_round_combiners(combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) if round_start: - print("CONTROL: round start policy met, participating combiners {}".format( - validating_combiners), flush=True) + print( + "CONTROL: round start policy met, participating combiners {}".format( + validating_combiners + ), + flush=True, + ) else: - print("CONTROL: Round start policy not met, skipping round!", flush=True) + print( + "CONTROL: Round start policy not met, skipping round!", + flush=True, + ) return None # Synch combiners with latest model and trigger inference diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index e38d31e38..077620c14 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -11,7 +11,7 @@ from fedn.network.state import ReducerState # Maximum number of tries to connect to statestore and retrieve storage configuration -MAX_TRIES_BACKEND = os.getenv('MAX_TRIES_BACKEND', 10) +MAX_TRIES_BACKEND = 
os.getenv("MAX_TRIES_BACKEND", 10) class UnsupportedStorageBackend(Exception): @@ -27,7 +27,7 @@ class MisconfiguredHelper(Exception): class ControlBase(ABC): - """ Base class and interface for a global controller. + """Base class and interface for a global controller. Override this class to implement a global training strategy (control). :param statestore: The statestore object. @@ -36,7 +36,7 @@ class ControlBase(ABC): @abstractmethod def __init__(self, statestore): - """ Constructor. """ + """Constructor.""" self._state = ReducerState.setup self.statestore = statestore @@ -52,26 +52,36 @@ def __init__(self, statestore): not_ready = False else: print( - "REDUCER CONTROL: Storage backend not configured, waiting...", flush=True) + "REDUCER CONTROL: Storage backend not configured, waiting...", + flush=True, + ) sleep(5) tries += 1 if tries > MAX_TRIES_BACKEND: raise Exception except Exception: print( - "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", flush=True) + "REDUCER CONTROL: Failed to retrive storage configuration, exiting.", + flush=True, + ) raise MisconfiguredStorageBackend() - if storage_config['storage_type'] == 'S3': - self.model_repository = S3ModelRepository(storage_config['storage_config']) + if storage_config["storage_type"] == "S3": + self.model_repository = S3ModelRepository( + storage_config["storage_config"] + ) else: - print("REDUCER CONTROL: Unsupported storage backend, exiting.", flush=True) + print( + "REDUCER CONTROL: Unsupported storage backend, exiting.", + flush=True, + ) raise UnsupportedStorageBackend() # The tracer is a helper that manages state in the database backend statestore_config = statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], statestore_config["network_id"] + ) if self.statestore.is_inited(): self._state = ReducerState.idle @@ -89,7 +99,7 @@ def reduce(self, combiners): pass def get_helper(self): 
- """ Get a helper instance from global config. + """Get a helper instance from global config. :return: Helper instance. :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` @@ -97,11 +107,15 @@ def get_helper(self): helper_type = self.statestore.get_helper() helper = fedn.utils.helpers.get_helper(helper_type) if not helper: - raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) + raise MisconfiguredHelper( + "Unsupported helper type {}, please configure compute_package.helper !".format( + helper_type + ) + ) return helper def get_state(self): - """ Get the current state of the controller. + """Get the current state of the controller. :return: The current state. :rtype: :class:`fedn.network.state.ReducerState` @@ -109,7 +123,7 @@ def get_state(self): return self._state def idle(self): - """ Check if the controller is idle. + """Check if the controller is idle. :return: True if idle, False otherwise. :rtype: bool @@ -139,7 +153,7 @@ def get_latest_round_id(self): if not last_round: return 0 else: - return last_round['round_id'] + return last_round["round_id"] def get_latest_round(self): round = self.statestore.get_latest_round() @@ -153,27 +167,29 @@ def get_compute_package_name(self): definition = self.statestore.get_compute_package() if definition: try: - package_name = definition['filename'] + package_name = definition["filename"] return package_name except (IndexError, KeyError): print( - "No context filename set for compute context definition", flush=True) + "No context filename set for compute context definition", + flush=True, + ) return None else: return None def set_compute_package(self, filename, path): - """ Persist the configuration for the compute package. 
""" + """Persist the configuration for the compute package.""" self.model_repository.set_compute_package(filename, path) self.statestore.set_compute_package(filename) - def get_compute_package(self, compute_package=''): + def get_compute_package(self, compute_package=""): """ :param compute_package: :return: """ - if compute_package == '': + if compute_package == "": compute_package = self.get_compute_package_name() if compute_package: return self.model_repository.get_compute_package(compute_package) @@ -181,42 +197,49 @@ def get_compute_package(self, compute_package=''): return None def new_session(self, config): - """ Initialize a new session in backend db. """ + """Initialize a new session in backend db.""" if "session_id" not in config.keys(): session_id = uuid.uuid4() - config['session_id'] = str(session_id) + config["session_id"] = str(session_id) else: - session_id = config['session_id'] + session_id = config["session_id"] self.tracer.new_session(id=session_id) self.tracer.set_session_config(session_id, config) def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update. """ + """Call Combiner server RPC to get a model update.""" cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) cl.append((combiner, response)) return cl - def commit(self, model_id, model=None): - """ Commit a model to the global model trail. The model commited becomes the lastest consensus model. """ + def commit(self, model_id, model=None, session_id=None): + """Commit a model to the global model trail. 
The model commited becomes the lastest consensus model.""" helper = self.get_helper() if model is not None: - print("CONTROL: Saving model file temporarily to disk...", flush=True) + print( + "CONTROL: Saving model file temporarily to disk...", flush=True + ) outfile_name = helper.save(model) print("CONTROL: Uploading model to Minio...", flush=True) model_id = self.model_repository.set_model( - outfile_name, is_file=True) + outfile_name, is_file=True + ) print("CONTROL: Deleting temporary model file...", flush=True) os.unlink(outfile_name) - print("CONTROL: Committing model {} to global model trail in statestore...".format( - model_id), flush=True) - self.statestore.set_latest_model(model_id) + print( + "CONTROL: Committing model {} to global model trail in statestore...".format( + model_id + ), + flush=True, + ) + self.statestore.set_latest_model(model_id, session_id) def get_combiner(self, name): for combiner in self.network.get_combiners(): @@ -226,7 +249,7 @@ def get_combiner(self, name): def get_participating_combiners(self, combiner_round_config): """Assemble a list of combiners able to participate in a round as - descibed by combiner_round_config. + descibed by combiner_round_config. """ combiners = [] for combiner in self.network.get_combiners(): @@ -238,45 +261,47 @@ def get_participating_combiners(self, combiner_round_config): if combiner_state is not None: is_participating = self.evaluate_round_participation_policy( - combiner_round_config, combiner_state) + combiner_round_config, combiner_state + ) if is_participating: combiners.append((combiner, combiner_round_config)) return combiners - def evaluate_round_participation_policy(self, compute_plan, combiner_state): - """ Evaluate policy for combiner round-participation. - A combiner participates if it is responsive and reports enough - active clients to participate in the round. 
+ def evaluate_round_participation_policy( + self, compute_plan, combiner_state + ): + """Evaluate policy for combiner round-participation. + A combiner participates if it is responsive and reports enough + active clients to participate in the round. """ - if compute_plan['task'] == 'training': - nr_active_clients = int(combiner_state['nr_active_trainers']) - elif compute_plan['task'] == 'validation': - nr_active_clients = int(combiner_state['nr_active_validators']) + if compute_plan["task"] == "training": + nr_active_clients = int(combiner_state["nr_active_trainers"]) + elif compute_plan["task"] == "validation": + nr_active_clients = int(combiner_state["nr_active_validators"]) else: print("Invalid task type!", flush=True) return False - if int(compute_plan['clients_required']) <= nr_active_clients: + if int(compute_plan["clients_required"]) <= nr_active_clients: return True else: return False def evaluate_round_start_policy(self, combiners): - """ Check if the policy to start a round is met. """ + """Check if the policy to start a round is met.""" if len(combiners) > 0: - return True else: return False def evaluate_round_validity_policy(self, combiners): - """ Check if the round should be seen as valid. + """Check if the round should be seen as valid. - At the end of the round, before committing a model to the global model trail, - we check if the round validity policy has been met. This can involve - e.g. asserting that a certain number of combiners have reported in an - updated model, or that criteria on model performance have been met. + At the end of the round, before committing a model to the global model trail, + we check if the round validity policy has been met. This can involve + e.g. asserting that a certain number of combiners have reported in an + updated model, or that criteria on model performance have been met. 
""" if combiners.keys() == []: return False @@ -294,7 +319,8 @@ def _select_participating_combiners(self, compute_plan): if combiner_state: is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state) + compute_plan, combiner_state + ) if is_participating: participating_combiners.append((combiner, compute_plan)) return participating_combiners diff --git a/fedn/fedn/network/dashboard/restservice.py b/fedn/fedn/network/dashboard/restservice.py index 3c272349e..14e7266bb 100644 --- a/fedn/fedn/network/dashboard/restservice.py +++ b/fedn/fedn/network/dashboard/restservice.py @@ -23,8 +23,8 @@ from fedn.network.state import ReducerState, ReducerStateToString from fedn.utils.checksum import sha -UPLOAD_FOLDER = '/app/client/package/' -ALLOWED_EXTENSIONS = {'gz', 'bz2', 'tar', 'zip', 'tgz'} +UPLOAD_FOLDER = "/app/client/package/" +ALLOWED_EXTENSIONS = {"gz", "bz2", "tar", "zip", "tgz"} def allowed_file(filename): @@ -33,8 +33,10 @@ def allowed_file(filename): :param filename: :return: """ - return '.' in filename and \ - filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return ( + "." 
in filename + and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS + ) def encode_auth_token(secret_key): @@ -43,16 +45,17 @@ def encode_auth_token(secret_key): """ try: payload = { - 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=90, seconds=0), - 'iat': datetime.datetime.utcnow(), - 'status': 'Success' + "exp": datetime.datetime.utcnow() + + datetime.timedelta(days=90, seconds=0), + "iat": datetime.datetime.utcnow(), + "status": "Success", } - token = jwt.encode( - payload, - secret_key, - algorithm='HS256' + token = jwt.encode(payload, secret_key, algorithm="HS256") + print( + "\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n".format( + token + ) ) - print('\n\n\nSECURE MODE ENABLED, USE TOKEN TO ACCESS REDUCER: **** {} ****\n\n\n'.format(token)) return token except Exception as e: return e @@ -64,56 +67,49 @@ def decode_auth_token(auth_token, secret): :return: string """ try: - payload = jwt.decode( - auth_token, - secret, - algorithms=['HS256'] - ) + payload = jwt.decode(auth_token, secret, algorithms=["HS256"]) return payload["status"] except jwt.ExpiredSignatureError as e: print(e) - return 'Token has expired.' + return "Token has expired." except jwt.InvalidTokenError as e: print(e) - return 'Invalid token.' + return "Invalid token." 
class ReducerRestService: - """ - - """ + """ """ def __init__(self, config, control, statestore, certificate_manager): - print("config object!: \n\n\n\n{}".format(config)) - if config['host']: - self.host = config['host'] + if config["host"]: + self.host = config["host"] else: self.host = None - self.name = config['name'] + self.name = config["name"] - self.port = config['port'] - self.network_id = config['name'] + '-network' + self.port = config["port"] + self.network_id = config["name"] + "-network" - if 'token' in config.keys(): + if "token" in config.keys(): self.token_auth_enabled = True else: self.token_auth_enabled = False - if 'secret_key' in config.keys(): - self.SECRET_KEY = config['secret_key'] + if "secret_key" in config.keys(): + self.SECRET_KEY = config["secret_key"] else: self.SECRET_KEY = None - if 'use_ssl' in config.keys(): - self.use_ssl = config['use_ssl'] + if "use_ssl" in config.keys(): + self.use_ssl = config["use_ssl"] self.remote_compute_package = config["remote_compute_package"] if self.remote_compute_package: - self.package = 'remote' + self.package = "remote" else: - self.package = 'local' + self.package = "local" self.control = control self.statestore = statestore @@ -125,9 +121,7 @@ def to_dict(self): :return: """ - data = { - 'name': self.name - } + data = {"name": self.name} return data def check_compute_package(self): @@ -165,24 +159,40 @@ def check_configured_response(self): :rtype: json """ if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Controller is not configured."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Controller is not configured.", + } + ) if not self.check_compute_package(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Compute package is not configured. 
Please upload the compute package."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Compute package is not configured. Please upload the compute package.", + } + ) if not self.check_initial_model(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Initial model is not configured. Please upload the model."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Initial model is not configured. Please upload the model.", + } + ) if not self.control.idle(): - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Conroller is not in idle state, try again later. "}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Conroller is not in idle state, try again later. ", + } + ) return None def check_configured(self): @@ -192,17 +202,29 @@ def check_configured(self): :return: Rendered html template or None """ if not self.check_compute_package(): - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=False, - message='Please set the compute package') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=False, + message="Please set the compute package", + ) if self.control.state() == ReducerState.setup: - return render_template('setup.html', client=self.name, state=ReducerStateToString(self.control.state()), - logs=None, refresh=True, - message='Warning. Reducer is not base-configured. please do so with config file.') + return render_template( + "setup.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + logs=None, + refresh=True, + message="Warning. Reducer is not base-configured. 
please do so with config file.", + ) if not self.check_initial_model(): - return render_template('setup_model.html', message="Please set the initial model.") + return render_template( + "setup_model.html", message="Please set the initial model." + ) return None @@ -216,31 +238,37 @@ def authorize(self, r, secret): """ try: # Get token - if 'Authorization' in r.headers: # header auth - request_token = r.headers.get('Authorization').split()[1] - elif 'token' in r.args: # args auth - request_token = str(r.args.get('token')) - elif 'fedn_token' in r.cookies: - request_token = r.cookies.get('fedn_token') + if "Authorization" in r.headers: # header auth + request_token = r.headers.get("Authorization").split()[1] + elif "token" in r.args: # args auth + request_token = str(r.args.get("token")) + elif "fedn_token" in r.cookies: + request_token = r.cookies.get("fedn_token") else: # no token provided - print('Authorization failed. No token provided.', flush=True) + print("Authorization failed. No token provided.", flush=True) abort(401) # Log token and secret print( - f'Secret: {secret}. Request token: {request_token}.', flush=True) + f"Secret: {secret}. Request token: {request_token}.", + flush=True, + ) # Authenticate status = decode_auth_token(request_token, secret) - if status == 'Success': + if status == "Success": return True else: - print('Authorization failed. Status: "{}"'.format( - status), flush=True) + print( + 'Authorization failed. Status: "{}"'.format(status), + flush=True, + ) abort(401) except Exception as e: - print('Authorization failed. Expection encountered: "{}".'.format( - e), flush=True) + print( + 'Authorization failed. 
Expection encountered: "{}".'.format(e), + flush=True, + ) abort(401) def run(self): @@ -250,10 +278,10 @@ def run(self): """ app = Flask(__name__) - app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER - app.config['SECRET_KEY'] = self.SECRET_KEY + app.config["UPLOAD_FOLDER"] = UPLOAD_FOLDER + app.config["SECRET_KEY"] = self.SECRET_KEY - @app.route('/') + @app.route("/") def index(): """ @@ -261,7 +289,7 @@ def index(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Render template not_configured_template = self.check_configured() @@ -269,29 +297,37 @@ def index(): template = not_configured_template else: events = self.control.get_events() - message = request.args.get('message', None) - message_type = request.args.get('message_type', None) - template = render_template('events.html', client=self.name, state=ReducerStateToString(self.control.state()), - events=events, - logs=None, refresh=True, configured=True, message=message, message_type=message_type) + message = request.args.get("message", None) + message_type = request.args.get("message_type", None) + template = render_template( + "events.html", + client=self.name, + state=ReducerStateToString(self.control.state()), + events=events, + logs=None, + refresh=True, + configured=True, + message=message, + message_type=message_type, + ) # Set token cookie in response if needed response = make_response(template) - if 'token' in request.args: # args auth - response.set_cookie('fedn_token', str(request.args['token'])) + if "token" in request.args: # args auth + response.set_cookie("fedn_token", str(request.args["token"])) # Return response return response - @app.route('/status') + @app.route("/status") def status(): """ :return: """ - return {'state': ReducerStateToString(self.control.state())} + return {"state": ReducerStateToString(self.control.state())} - @app.route('/netgraph') + @app.route("/netgraph") def 
netgraph(): """ Creates nodes and edges for network graph @@ -299,16 +335,18 @@ def netgraph(): :return: nodes and edges as keys :rtype: dict """ - result = {'nodes': [], 'edges': []} - - result['nodes'].append({ - "id": "reducer", - "label": "Reducer", - "role": 'reducer', - "status": 'active', - "name": 'reducer', # TODO: get real host name - "type": 'reducer', - }) + result = {"nodes": [], "edges": []} + + result["nodes"].append( + { + "id": "reducer", + "label": "Reducer", + "role": "reducer", + "status": "active", + "name": "reducer", # TODO: get real host name + "type": "reducer", + } + ) combiner_info = combiner_status() client_info = client_status() @@ -319,49 +357,55 @@ def netgraph(): for combiner in combiner_info: print("combiner info {}".format(combiner_info), flush=True) try: - result['nodes'].append({ - "id": combiner['name'], # "n{}".format(count), - "label": "Combiner ({} clients)".format(combiner['nr_active_clients']), - "role": 'combiner', - "status": 'active', # TODO: Hard-coded, combiner_info does not contain status - "name": combiner['name'], - "type": 'combiner', - }) + result["nodes"].append( + { + "id": combiner["name"], # "n{}".format(count), + "label": "Combiner ({} clients)".format( + combiner["nr_active_clients"] + ), + "role": "combiner", + "status": "active", # TODO: Hard-coded, combiner_info does not contain status + "name": combiner["name"], + "type": "combiner", + } + ) except Exception as err: print(err) - for client in client_info['active_clients']: + for client in client_info["active_clients"]: try: - if client['status'] != 'offline': - result['nodes'].append({ - "id": str(client['_id']), - "label": "Client", - "role": client['role'], - "status": client['status'], - "name": client['name'], - "combiner": client['combiner'], - "type": 'client', - }) + if client["status"] != "offline": + result["nodes"].append( + { + "id": str(client["_id"]), + "label": "Client", + "role": client["role"], + "status": client["status"], + "name": 
client["name"], + "combiner": client["combiner"], + "type": "client", + } + ) except Exception as err: print(err) count = 0 - for node in result['nodes']: + for node in result["nodes"]: try: - if node['type'] == 'combiner': - result['edges'].append( + if node["type"] == "combiner": + result["edges"].append( { "id": "e{}".format(count), - "source": node['id'], - "target": 'reducer', + "source": node["id"], + "target": "reducer", } ) - elif node['type'] == 'client': - result['edges'].append( + elif node["type"] == "client": + result["edges"].append( { "id": "e{}".format(count), - "source": node['combiner'], - "target": node['id'], + "source": node["combiner"], + "target": node["id"], } ) except Exception: @@ -369,59 +413,75 @@ def netgraph(): count = count + 1 return result - @app.route('/networkgraph') + @app.route("/networkgraph") def network_graph(): - try: plot = Plot(self.control.statestore) result = netgraph() - df_nodes = pd.DataFrame(result['nodes']) - df_edges = pd.DataFrame(result['edges']) + df_nodes = pd.DataFrame(result["nodes"]) + df_edges = pd.DataFrame(result["edges"]) graph = plot.make_netgraph_plot(df_edges, df_nodes) return json.dumps(json_item(graph, "myplot")) except Exception: raise # return '' - @app.route('/events') + @app.route("/events") def events(): """ :return: """ + response = self.control.get_events() + events = [] + + result = response["result"] + + for evt in result: + events.append(evt) + + return jsonify({"result": events, "count": response["count"]}) + json_docs = [] for doc in self.control.get_events(): json_doc = json.dumps(doc, default=json_util.default) json_docs.append(json_doc) json_docs.reverse() - return {'events': json_docs} - @app.route('/add') + return {"events": json_docs} + + @app.route("/add") def add(): - """ Add a combiner to the network. 
""" + """Add a combiner to the network.""" print("Adding combiner to network:", flush=True) if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) if self.control.state() == ReducerState.setup: - return jsonify({'status': 'retry'}) - - name = request.args.get('name', None) - address = str(request.args.get('address', None)) - fqdn = str(request.args.get('fqdn', None)) - port = request.args.get('port', None) - secure_grpc = request.args.get('secure', None) - - if port is None or address is None or name is None or secure_grpc is None: + return jsonify({"status": "retry"}) + + name = request.args.get("name", None) + address = str(request.args.get("address", None)) + fqdn = str(request.args.get("fqdn", None)) + port = request.args.get("port", None) + secure_grpc = request.args.get("secure", None) + + if ( + port is None + or address is None + or name is None + or secure_grpc is None + ): return "Please specify correct parameters." 
# Try to retrieve combiner from db combiner = self.control.network.get_combiner(name) if not combiner: - if secure_grpc == 'True': + if secure_grpc == "True": certificate, key = self.certificate_manager.get_or_create( - address).get_keypair_raw() + address + ).get_keypair_raw() _ = base64.b64encode(certificate) _ = base64.b64encode(key) @@ -437,23 +497,24 @@ def add(): port=port, certificate=copy.deepcopy(certificate), key=copy.deepcopy(key), - ip=request.remote_addr) + ip=request.remote_addr, + ) self.control.network.add_combiner(combiner) combiner = self.control.network.get_combiner(name) ret = { - 'status': 'added', - 'storage': self.control.statestore.get_storage_backend(), - 'statestore': self.control.statestore.get_config(), - 'certificate': combiner.get_certificate(), - 'key': combiner.get_key() + "status": "added", + "storage": self.control.statestore.get_storage_backend(), + "statestore": self.control.statestore.get_config(), + "certificate": combiner.get_certificate(), + "key": combiner.get_key(), } return jsonify(ret) - @app.route('/eula', methods=['GET', 'POST']) + @app.route("/eula", methods=["GET", "POST"]) def eula(): """ @@ -462,9 +523,9 @@ def eula(): for r in request.headers: print("header contains: {}".format(r), flush=True) - return render_template('eula.html', configured=True) + return render_template("eula.html", configured=True) - @app.route('/models', methods=['GET', 'POST']) + @app.route("/models", methods=["GET", "POST"]) def models(): """ @@ -472,13 +533,12 @@ def models(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) - if request.method == 'POST': + if request.method == "POST": # upload seed file - uploaded_seed = request.files['seed'] + uploaded_seed = request.files["seed"] if uploaded_seed: - a = BytesIO() a.seek(0, 0) uploaded_seed.seek(0) @@ -504,23 +564,31 @@ def models(): h_latest_model_id = 
self.statestore.get_latest_model() model_info = self.control.get_model_info() - return render_template('models.html', box_plot=box_plot, metrics=valid_metrics, h_latest_model_id=h_latest_model_id, seed=True, - model_info=model_info, configured=True) + return render_template( + "models.html", + box_plot=box_plot, + metrics=valid_metrics, + h_latest_model_id=h_latest_model_id, + seed=True, + model_info=model_info, + configured=True, + ) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/delete_model_trail', methods=['GET', 'POST']) + @app.route("/delete_model_trail", methods=["GET", "POST"]) def delete_model_trail(): """ :return: """ - if request.method == 'POST': - + if request.method == "POST": statestore_config = self.control.statestore.get_config() self.tracer = MongoTracer( - statestore_config['mongo_config'], statestore_config['network_id']) + statestore_config["mongo_config"], + statestore_config["network_id"], + ) try: self.control.drop_models() except Exception: @@ -528,28 +596,28 @@ def delete_model_trail(): # drop objects in minio self.control.delete_bucket_objects() - return redirect(url_for('models')) + return redirect(url_for("models")) seed = True - return redirect(url_for('models', seed=seed)) + return redirect(url_for("models", seed=seed)) - @app.route('/drop_control', methods=['GET', 'POST']) + @app.route("/drop_control", methods=["GET", "POST"]) def drop_control(): """ :return: """ - if request.method == 'POST': + if request.method == "POST": self.control.statestore.drop_control() - return redirect(url_for('control')) - return redirect(url_for('control')) + return redirect(url_for("control")) + return redirect(url_for("control")) # http://localhost:8090/control?rounds=4&model_id=879fa112-c861-4cb1-a25d-775153e5b548 - @app.route('/control', methods=['GET', 'POST']) + @app.route("/control", methods=["GET", "POST"]) def control(): - """ Main page for round control. 
Configure, start and stop training sessions. """ + """Main page for round control. Configure, start and stop training sessions.""" # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -560,60 +628,88 @@ def control(): if self.remote_compute_package: try: - self.current_compute_context = self.control.get_compute_package_name() + self.current_compute_context = ( + self.control.get_compute_package_name() + ) except Exception: self.current_compute_context = None else: self.current_compute_context = "None:Local" if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=state, refresh=refresh, message="Reducer is in monitoring state")) - - if request.method == 'POST': + url_for( + "index", + state=state, + refresh=refresh, + message="Reducer is in monitoring state", + ) + ) + + if request.method == "POST": # Get session configuration - round_timeout = float(request.form.get('timeout', 180)) - buffer_size = int(request.form.get('buffer_size', -1)) - rounds = int(request.form.get('rounds', 1)) - delete_models = request.form.get('delete_models', True) - task = (request.form.get('task', '')) - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + round_timeout = float(request.form.get("timeout", 180)) + buffer_size = int(request.form.get("buffer_size", -1)) + rounds = int(request.form.get("rounds", 1)) + delete_models = request.form.get("delete_models", True) + task = request.form.get("task", "") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # checking if there are enough clients connected to start! 
clients_available = 0 for combiner in self.control.network.get_combiners(): try: combiner_state = combiner.report() - nac = combiner_state['nr_active_clients'] + nac = combiner_state["nr_active_clients"] clients_available = clients_available + int(nac) except Exception: pass if clients_available < clients_required: - return redirect(url_for('index', state=state, - message="Not enough clients available to start rounds! " - "check combiner client capacity", - message_type='warning')) + return redirect( + url_for( + "index", + state=state, + message="Not enough clients available to start rounds! " + "check combiner client capacity", + message_type="warning", + ) + ) - validate = request.form.get('validate', False) - if validate == 'False': + validate = request.form.get("validate", False) + if validate == "False": validate = False - helper_type = request.form.get('helper', 'keras') + helper_type = request.form.get("helper", "keras") # self.control.statestore.set_framework(helper_type) latest_model_id = self.statestore.get_latest_model() - config = {'round_timeout': round_timeout, 'buffer_size': buffer_size, - 'model_id': latest_model_id, 'rounds': rounds, 'delete_models_storage': delete_models, - 'clients_required': clients_required, - 'clients_requested': clients_requested, 'task': task, - 'validate': validate, 'helper_type': helper_type} - - threading.Thread(target=self.control.session, - args=(config,)).start() + config = { + "round_timeout": round_timeout, + "buffer_size": buffer_size, + "model_id": latest_model_id, + "rounds": rounds, + "delete_models_storage": delete_models, + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": task, + "validate": validate, + "helper_type": helper_type, + } + + threading.Thread( + target=self.control.session, args=(config,) + ).start() - return redirect(url_for('index', state=state, refresh=refresh, message="Sent execution plan.", - message_type='SUCCESS')) + return redirect( + url_for( + 
"index", + state=state, + refresh=refresh, + message="Sent execution plan.", + message_type="SUCCESS", + ) + ) else: seed_model_id = None @@ -624,42 +720,53 @@ def control(): except Exception: pass - return render_template('index.html', latest_model_id=latest_model_id, - compute_package=self.current_compute_context, - seed_model_id=seed_model_id, - helper=self.control.statestore.get_helper(), validate=True, configured=True) - - @app.route('/assign') + return render_template( + "index.html", + latest_model_id=latest_model_id, + compute_package=self.current_compute_context, + seed_model_id=seed_model_id, + helper=self.control.statestore.get_helper(), + validate=True, + configured=True, + ) + + @app.route("/assign") def assign(): - """Handle client assignment requests. """ + """Handle client assignment requests.""" if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) response = self.check_configured_response() if response: return response - name = request.args.get('name', None) - combiner_preferred = request.args.get('combiner', None) + name = request.args.get("name", None) + combiner_preferred = request.args.get("combiner", None) if combiner_preferred: - combiner = self.control.network.get_combiner(combiner_preferred) + combiner = self.control.network.get_combiner( + combiner_preferred + ) else: combiner = self.control.network.find_available_combiner() if combiner is None: - return jsonify({'status': 'retry', - 'package': self.package, - 'msg': "Failed to assign to a combiner, try again later."}) + return jsonify( + { + "status": "retry", + "package": self.package, + "msg": "Failed to assign to a combiner, try again later.", + } + ) client = { - 'name': name, - 'combiner_preferred': combiner_preferred, - 'combiner': combiner.name, - 'ip': request.remote_addr, - 'status': 'available' + "name": name, + "combiner_preferred": combiner_preferred, + "combiner": combiner.name, + "ip": 
request.remote_addr, + "status": "available", } # Add client to database @@ -668,25 +775,25 @@ def assign(): # Return connection information to client if combiner.certificate: cert_b64 = base64.b64encode(combiner.certificate) - cert = str(cert_b64).split('\'')[1] + cert = str(cert_b64).split("'")[1] else: cert = None response = { - 'status': 'assigned', - 'host': combiner.address, - 'fqdn': combiner.fqdn, - 'package': self.package, - 'ip': combiner.ip, - 'port': combiner.port, - 'certificate': cert, - 'model_type': self.control.statestore.get_helper() + "status": "assigned", + "host": combiner.address, + "fqdn": combiner.fqdn, + "package": self.package, + "ip": combiner.ip, + "port": combiner.port, + "certificate": cert, + "model_type": self.control.statestore.get_helper(), } return jsonify(response) def combiner_status(): - """ Get current status reports from all combiners registered in the network. + """Get current status reports from all combiners registered in the network. :return: """ @@ -711,67 +818,90 @@ def client_status(): all_active_validators = [] for client in combiner_info: - active_trainers_str = client['active_trainers'] - active_validators_str = client['active_validators'] + active_trainers_str = client["active_trainers"] + active_validators_str = client["active_validators"] active_trainers_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_trainers_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_trainers_str # noqa: W605 + ).replace( + "name:", " " + ) active_validators_str = re.sub( - '[^a-zA-Z0-9-:\n\.]', '', active_validators_str).replace('name:', ' ') # noqa: W605 + "[^a-zA-Z0-9-:\n\.]", "", active_validators_str # noqa: W605 + ).replace( + "name:", " " + ) all_active_trainers.extend( - ' '.join(active_trainers_str.split(" ")).split()) + " ".join(active_trainers_str.split(" ")).split() + ) all_active_validators.extend( - ' '.join(active_validators_str.split(" ")).split()) + " ".join(active_validators_str.split(" 
")).split() + ) active_trainers_list = [ - client for client in client_info if client['name'] in all_active_trainers] + client + for client in client_info + if client["name"] in all_active_trainers + ] active_validators_list = [ - cl for cl in client_info if cl['name'] in all_active_validators] + cl + for cl in client_info + if cl["name"] in all_active_validators + ] all_clients = [cl for cl in client_info] for client in all_clients: - status = 'offline' - role = 'None' + status = "offline" + role = "None" self.control.network.update_client_data( - client, status, role) + client, status, role + ) - all_active_clients = active_validators_list + active_trainers_list + all_active_clients = ( + active_validators_list + active_trainers_list + ) for client in all_active_clients: - status = 'active' - if client in active_trainers_list and client in active_validators_list: - role = 'trainer-validator' + status = "active" + if ( + client in active_trainers_list + and client in active_validators_list + ): + role = "trainer-validator" elif client in active_trainers_list: - role = 'trainer' + role = "trainer" elif client in active_validators_list: - role = 'validator' + role = "validator" else: - role = 'unknown' + role = "unknown" self.control.network.update_client_data( - client, status, role) - - return {'active_clients': all_clients, - 'active_trainers': active_trainers_list, - 'active_validators': active_validators_list - } + client, status, role + ) + + return { + "active_clients": all_clients, + "active_trainers": active_trainers_list, + "active_validators": active_validators_list, + } except Exception: pass - return {'active_clients': [], - 'active_trainers': [], - 'active_validators': [] - } + return { + "active_clients": [], + "active_trainers": [], + "active_validators": [], + } - @app.route('/metric_type', methods=['GET', 'POST']) + @app.route("/metric_type", methods=["GET", "POST"]) def change_features(): """ :return: """ - feature = request.args['selected'] + 
feature = request.args["selected"] plot = Plot(self.control.statestore) graphJSON = plot.create_box_plot(feature) return graphJSON - @app.route('/dashboard') + @app.route("/dashboard") def dashboard(): """ @@ -779,7 +909,7 @@ def dashboard(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -793,16 +923,18 @@ def dashboard(): clients_plot = plot.create_client_plot() client_histogram_plot = plot.create_client_histogram_plot() - return render_template('dashboard.html', show_plot=True, - table_plot=table_plot, - timeline_plot=timeline_plot, - clients_plot=clients_plot, - client_histogram_plot=client_histogram_plot, - combiners_plot=combiners_plot, - configured=True - ) - - @app.route('/network') + return render_template( + "dashboard.html", + show_plot=True, + table_plot=table_plot, + timeline_plot=timeline_plot, + clients_plot=clients_plot, + client_histogram_plot=client_histogram_plot, + combiners_plot=combiners_plot, + configured=True, + ) + + @app.route("/network") def network(): """ @@ -810,7 +942,7 @@ def network(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) not_configured = self.check_configured() if not_configured: @@ -821,17 +953,19 @@ def network(): combiner_info = combiner_status() active_clients = client_status() # print(combiner_info, flush=True) - return render_template('network.html', network_plot=True, - round_time_plot=round_time_plot, - mem_cpu_plot=mem_cpu_plot, - combiner_info=combiner_info, - active_clients=active_clients['active_clients'], - active_trainers=active_clients['active_trainers'], - active_validators=active_clients['active_validators'], - configured=True - ) - - @app.route('/config/download', methods=['GET']) + return render_template( + "network.html", + 
network_plot=True, + round_time_plot=round_time_plot, + mem_cpu_plot=mem_cpu_plot, + combiner_info=combiner_info, + active_clients=active_clients["active_clients"], + active_trainers=active_clients["active_trainers"], + active_validators=active_clients["active_validators"], + configured=True, + ) + + @app.route("/config/download", methods=["GET"]) def config_download(): """ @@ -839,8 +973,8 @@ def config_download(): """ chk_string = "" name = self.control.get_compute_package_name() - if name is None or name == '': - chk_string = '' + if name is None or name == "": + chk_string = "" else: file_path = os.path.join(UPLOAD_FOLDER, name) print("trying to get {}".format(file_path)) @@ -848,7 +982,7 @@ def config_download(): try: sum = str(sha(file_path)) except FileNotFoundError: - sum = '' + sum = "" chk_string = "checksum: {}".format(sum) network_id = self.network_id @@ -857,20 +991,24 @@ def config_download(): ctx = """network_id: {network_id} discover_host: {discover_host} discover_port: {discover_port} -{chk_string}""".format(network_id=network_id, - discover_host=discover_host, - discover_port=discover_port, - chk_string=chk_string) +{chk_string}""".format( + network_id=network_id, + discover_host=discover_host, + discover_port=discover_port, + chk_string=chk_string, + ) obj = BytesIO() - obj.write(ctx.encode('UTF-8')) + obj.write(ctx.encode("UTF-8")) obj.seek(0) - return send_file(obj, - as_attachment=True, - download_name='client.yaml', - mimetype='application/x-yaml') - - @app.route('/context', methods=['GET', 'POST']) + return send_file( + obj, + as_attachment=True, + download_name="client.yaml", + mimetype="application/x-yaml", + ) + + @app.route("/context", methods=["GET", "POST"]) def context(): """ @@ -878,78 +1016,85 @@ def context(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # if reset is not empty then allow context re-set - reset = 
request.args.get('reset', None) + reset = request.args.get("reset", None) if reset: - return render_template('context.html') + return render_template("context.html") - if request.method == 'POST': + if request.method == "POST": + if "file" not in request.files: + flash("No file part") + return redirect(url_for("context")) - if 'file' not in request.files: - flash('No file part') - return redirect(url_for('context')) - - file = request.files['file'] - helper_type = request.form.get('helper', 'kerashelper') + file = request.files["file"] + helper_type = request.form.get("helper", "kerashelper") # if user does not select file, browser also # submit an empty part without filename - if file.filename == '': - flash('No selected file') - return redirect(url_for('context')) + if file.filename == "": + flash("No selected file") + return redirect(url_for("context")) if file and allowed_file(file.filename): filename = secure_filename(file.filename) file_path = os.path.join( - app.config['UPLOAD_FOLDER'], filename) + app.config["UPLOAD_FOLDER"], filename + ) file.save(file_path) - if self.control.state() == ReducerState.instructing or self.control.state() == ReducerState.monitoring: + if ( + self.control.state() == ReducerState.instructing + or self.control.state() == ReducerState.monitoring + ): return "Not allowed to change context while execution is ongoing." 
self.control.set_compute_package(filename, file_path) self.control.statestore.set_helper(helper_type) - return redirect(url_for('control')) + return redirect(url_for("control")) - name = request.args.get('name', '') + name = request.args.get("name", "") - if name == '': + if name == "": name = self.control.get_compute_package_name() - if name is None or name == '': - return render_template('context.html') + if name is None or name == "": + return render_template("context.html") # There is a potential race condition here, if one client requests a package and at # the same time another one triggers a fetch from Minio and writes to disk. try: mutex = Lock() mutex.acquire() - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: try: data = self.control.get_compute_package(name) - file_path = os.path.join(app.config['UPLOAD_FOLDER'], name) - with open(file_path, 'wb') as fh: + file_path = os.path.join(app.config["UPLOAD_FOLDER"], name) + with open(file_path, "wb") as fh: fh.write(data) - return send_from_directory(app.config['UPLOAD_FOLDER'], name, as_attachment=True) + return send_from_directory( + app.config["UPLOAD_FOLDER"], name, as_attachment=True + ) except Exception: raise finally: mutex.release() - return render_template('context.html') + return render_template("context.html") - @app.route('/checksum', methods=['GET', 'POST']) + @app.route("/checksum", methods=["GET", "POST"]) def checksum(): """ :return: """ # sum = '' - name = request.args.get('name', None) - if name == '' or name is None: + name = request.args.get("name", None) + if name == "" or name is None: name = self.control.get_compute_package_name() - if name is None or name == '': + if name is None or name == "": return jsonify({}) file_path = os.path.join(UPLOAD_FOLDER, name) @@ -958,13 +1103,13 @@ def checksum(): try: sum = str(sha(file_path)) except 
FileNotFoundError: - sum = '' + sum = "" - data = {'checksum': sum} + data = {"checksum": sum} return jsonify(data) - @app.route('/infer', methods=['POST']) + @app.route("/infer", methods=["POST"]) def infer(): """ @@ -972,7 +1117,7 @@ def infer(): """ # Token auth if self.token_auth_enabled: - self.authorize(request, app.config.get('SECRET_KEY')) + self.authorize(request, app.config.get("SECRET_KEY")) # Check configured not_configured = self.check_configured() @@ -982,7 +1127,9 @@ def infer(): # Check compute context if self.remote_compute_context: try: - self.current_compute_context = self.control.get_compute_package() + self.current_compute_context = ( + self.control.get_compute_package() + ) except Exception as e: print(e, flush=True) self.current_compute_context = None @@ -992,27 +1139,43 @@ def infer(): # Redirect if in monitoring state if self.control.state() == ReducerState.monitoring: return redirect( - url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Reducer is in monitoring state")) + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Reducer is in monitoring state", + ) + ) # POST params - timeout = int(request.form.get('timeout', 180)) - helper_type = request.form.get('helper', 'keras') - clients_required = request.form.get('clients_required', 1) - clients_requested = request.form.get('clients_requested', 8) + timeout = int(request.form.get("timeout", 180)) + helper_type = request.form.get("helper", "keras") + clients_required = request.form.get("clients_required", 1) + clients_requested = request.form.get("clients_requested", 8) # Start inference request - config = {'round_timeout': timeout, - 'model_id': self.statestore.get_latest_model(), - 'clients_required': clients_required, - 'clients_requested': clients_requested, - 'task': 'inference', - 'helper_type': helper_type} - threading.Thread(target=self.control.infer_instruct, - args=(config,)).start() + config = 
{ + "round_timeout": timeout, + "model_id": self.statestore.get_latest_model(), + "clients_required": clients_required, + "clients_requested": clients_requested, + "task": "inference", + "helper_type": helper_type, + } + threading.Thread( + target=self.control.infer_instruct, args=(config,) + ).start() # Redirect - return redirect(url_for('index', state=ReducerStateToString(self.control.state()), refresh=True, message="Sent execution plan (inference).", - message_type='SUCCESS')) + return redirect( + url_for( + "index", + state=ReducerStateToString(self.control.state()), + refresh=True, + message="Sent execution plan (inference).", + message_type="SUCCESS", + ) + ) if not self.host: bind = "0.0.0.0" diff --git a/fedn/fedn/network/dashboard/templates/events.html b/fedn/fedn/network/dashboard/templates/events.html index d3c34beb5..1fb5fac74 100644 --- a/fedn/fedn/network/dashboard/templates/events.html +++ b/fedn/fedn/network/dashboard/templates/events.html @@ -3,41 +3,44 @@ {% block content %} -
-
-
Events
-
-
- - - - + + + -
- -
+ }); + +
+
+ -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/fedn/fedn/network/statestore/mongostatestore.py b/fedn/fedn/network/statestore/mongostatestore.py index f991701d4..19d514f59 100644 --- a/fedn/fedn/network/statestore/mongostatestore.py +++ b/fedn/fedn/network/statestore/mongostatestore.py @@ -10,7 +10,7 @@ class MongoStateStore(StateStoreBase): - """ Statestore implementation using MongoDB. + """Statestore implementation using MongoDB. :param network_id: The network id. :type network_id: str @@ -21,7 +21,7 @@ class MongoStateStore(StateStoreBase): """ def __init__(self, network_id, config, model_storage_config): - """ Constructor.""" + """Constructor.""" self.__inited = False try: self.config = config @@ -29,19 +29,19 @@ def __init__(self, network_id, config, model_storage_config): self.mdb = connect_to_mongodb(self.config, self.network_id) # FEDn network - self.network = self.mdb['network'] - self.reducer = self.network['reducer'] - self.combiners = self.network['combiners'] - self.clients = self.network['clients'] - self.storage = self.network['storage'] + self.network = self.mdb["network"] + self.reducer = self.network["reducer"] + self.combiners = self.network["combiners"] + self.clients = self.network["clients"] + self.storage = self.network["storage"] # Control - self.control = self.mdb['control'] - self.package = self.control['package'] - self.state = self.control['state'] - self.model = self.control['model'] - self.sessions = self.control['sessions'] - self.rounds = self.control['rounds'] + self.control = self.mdb["control"] + self.package = self.control["package"] + self.state = self.control["state"] + self.model = self.control["model"] + self.sessions = self.control["sessions"] + self.rounds = self.control["rounds"] # Logging self.status = self.control["status"] @@ -62,7 +62,7 @@ def __init__(self, network_id, config, model_storage_config): self.__inited = True def is_inited(self): - """ Check if the statestore is intialized. 
+ """Check if the statestore is intialized. :return: True if initialized, else False. :rtype: bool @@ -76,105 +76,160 @@ def get_config(self): :rtype: dict """ data = { - 'type': 'MongoDB', - 'mongo_config': self.config, - 'network_id': self.network_id + "type": "MongoDB", + "mongo_config": self.config, + "network_id": self.network_id, } return data def state(self): - """ Get the current state. + """Get the current state. :return: The current state. :rtype: str """ - return StringToReducerState(self.state.find_one()['current_state']) + return StringToReducerState(self.state.find_one()["current_state"]) def transition(self, state): - """ Transition to a new state. + """Transition to a new state. :param state: The new state. :type state: str :return: """ - old_state = self.state.find_one({'state': 'current_state'}) + old_state = self.state.find_one({"state": "current_state"}) if old_state != state: - return self.state.update_one({'state': 'current_state'}, {'$set': {'state': ReducerStateToString(state)}}, True) + return self.state.update_one( + {"state": "current_state"}, + {"$set": {"state": ReducerStateToString(state)}}, + True, + ) else: - print("Not updating state, already in {}".format( - ReducerStateToString(state))) + print( + "Not updating state, already in {}".format( + ReducerStateToString(state) + ) + ) + + def get_sessions(self, limit=None, skip=None, sort_key="_id", sort_order=pymongo.DESCENDING): + """Get all sessions. + + :param limit: The maximum number of sessions to return. + :type limit: int + :param skip: The number of sessions to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: Dictionary of sessions in result (array of session objects) and count. + """ - def get_sessions(self): - """ Get all sessions. + result = None - :return: All sessions. 
- :rtype: ObjectID - """ - return self.sessions.find() + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = self.sessions.find().limit(limit).skip(skip).sort( + sort_key, sort_order + ) + else: + result = self.sessions.find().sort( + sort_key, sort_order + ) + + count = self.sessions.count_documents({}) + + return { + "result": result, + "count": count, + } def get_session(self, session_id): - """ Get session with id. + """Get session with id. :param session_id: The session id. :type session_id: str :return: The session. :rtype: ObjectID """ - return self.sessions.find_one({'session_id': session_id}) + return self.sessions.find_one({"session_id": session_id}) - def set_latest_model(self, model_id): - """ Set the latest model id. + def set_latest_model(self, model_id, session_id=None): + """Set the latest model id. :param model_id: The model id. :type model_id: str :return: """ - self.model.update_one({'key': 'current_model'}, { - '$set': {'model': model_id}}, True) - self.model.update_one({'key': 'model_trail'}, {'$push': {'model': model_id, 'committed_at': str(datetime.now())}}, - True) + committed_at = datetime.now() + + self.model.insert_one( + { + "key": "models", + "model": model_id, + "session_id": session_id, + "committed_at": committed_at, + } + ) + + self.model.update_one( + {"key": "current_model"}, {"$set": {"model": model_id}}, True + ) + self.model.update_one( + {"key": "model_trail"}, + { + "$push": { + "model": model_id, + "committed_at": str(committed_at), + } + }, + True, + ) def get_initial_model(self): - """ Return model_id for the initial model in the model trail + """Return model_id for the initial model in the model trail :return: The initial model id. None if no model is found. 
:rtype: str """ - result = self.model.find_one({'key': 'model_trail'}, sort=[ - ("committed_at", pymongo.ASCENDING)]) + result = self.model.find_one( + {"key": "model_trail"}, sort=[("committed_at", pymongo.ASCENDING)] + ) if result is None: return None try: - model_id = result['model'] - if model_id == '' or model_id == ' ': + model_id = result["model"] + if model_id == "" or model_id == " ": return None return model_id[0] except (KeyError, IndexError): return None def get_latest_model(self): - """ Return model_id for the latest model in the model_trail + """Return model_id for the latest model in the model_trail :return: The latest model id. None if no model is found. :rtype: str """ - result = self.model.find_one({'key': 'current_model'}) + result = self.model.find_one({"key": "current_model"}) if result is None: return None try: - model_id = result['model'] - if model_id == '' or model_id == ' ': + model_id = result["model"] + if model_id == "" or model_id == " ": return None return model_id except (KeyError, IndexError): return None def get_latest_round(self): - """ Get the id of the most recent round. + """Get the id of the most recent round. :return: The id of the most recent round. :rtype: ObjectId @@ -183,7 +238,7 @@ def get_latest_round(self): return self.rounds.find_one(sort=[("_id", pymongo.DESCENDING)]) def get_round(self, id): - """ Get round with id. + """Get round with id. :param id: id of round to get :type id: int @@ -191,10 +246,10 @@ def get_round(self, id): :rtype: ObjectId """ - return self.rounds.find_one({'round_id': str(id)}) + return self.rounds.find_one({"round_id": str(id)}) def get_rounds(self): - """ Get all rounds. + """Get all rounds. :return: All rounds. :rtype: ObjectId @@ -203,7 +258,7 @@ def get_rounds(self): return self.rounds.find() def get_validations(self, **kwargs): - """ Get validations from the database. + """Get validations from the database. 
:param kwargs: query to filter validations :type kwargs: dict @@ -215,7 +270,7 @@ def get_validations(self, **kwargs): return result def set_compute_package(self, filename): - """ Set the active compute package in statestore. + """Set the active compute package in statestore. :param filename: The filename of the compute package. :type filename: str @@ -223,66 +278,139 @@ def set_compute_package(self, filename): :rtype: bool """ self.control.package.update_one( - {'key': 'active'}, {'$set': {'filename': filename, 'committed_at': str(datetime.now())}}, True) - self.control.package.update_one({'key': 'package_trail'}, - {'$push': {'filename': filename, 'committed_at': str(datetime.now())}}, True) + {"key": "active"}, + { + "$set": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) + self.control.package.update_one( + {"key": "package_trail"}, + { + "$push": { + "filename": filename, + "committed_at": str(datetime.now()), + } + }, + True, + ) return True def get_compute_package(self): - """ Get the active compute package. + """Get the active compute package. :return: The active compute package. :rtype: ObjectID """ - ret = self.control.package.find({'key': 'active'}) + ret = self.control.package.find({"key": "active"}) try: retcheck = ret[0] - if retcheck is None or retcheck == '' or retcheck == ' ': # ugly check for empty string + if ( + retcheck is None or retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None def set_helper(self, helper): - """ Set the active helper package in statestore. + """Set the active helper package in statestore. :param helper: The name of the helper package. See helper.py for available helpers. 
:type helper: str :return: """ - self.control.package.update_one({'key': 'active'}, - {'$set': {'helper': helper}}, True) + self.control.package.update_one( + {"key": "active"}, {"$set": {"helper": helper}}, True + ) def get_helper(self): - """ Get the active helper package. + """Get the active helper package. :return: The active helper set for the package. :rtype: str """ - ret = self.control.package.find_one({'key': 'active'}) + ret = self.control.package.find_one({"key": "active"}) # if local compute package used, then 'package' is None # if not ret: # get framework from round_config instead # ret = self.control.config.find_one({'key': 'round_config'}) try: - retcheck = ret['helper'] - if retcheck == '' or retcheck == ' ': # ugly check for empty string + retcheck = ret["helper"] + if ( + retcheck == "" or retcheck == " " + ): # ugly check for empty string return None return retcheck except (KeyError, IndexError): return None + def list_models( + self, + session_id=None, + limit=None, + skip=None, + sort_key="committed_at", + sort_order=pymongo.DESCENDING, + ): + """List all models in the statestore. + + :param session_id: The session id. + :type session_id: str + :param limit: The maximum number of models to return. + :type limit: int + :param skip: The number of models to skip. + :type skip: int + :return: List of models. 
+ :rtype: list + """ + result = None + + find_option = ( + {"key": "models"} + if session_id is None + else {"key": "models", "session_id": session_id} + ) + + projection = {"_id": False, "key": False} + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + + result = ( + self.model.find(find_option, projection) + .limit(limit) + .skip(skip) + .sort(sort_key, sort_order) + ) + + else: + result = self.model.find(find_option, projection).sort( + sort_key, sort_order + ) + + count = self.model.count_documents(find_option) + + return { + "result": result, + "count": count, + } + def get_model_trail(self): - """ Get the model trail. + """Get the model trail. :return: dictionary of model_id: committed_at :rtype: dict """ - result = self.model.find_one({'key': 'model_trail'}) + result = self.model.find_one({"key": "model_trail"}) try: if result is not None: - committed_at = result['committed_at'] - model = result['model'] + committed_at = result["committed_at"] + model = result["model"] model_dictionary = dict(zip(model, committed_at)) return model_dictionary else: @@ -291,7 +419,7 @@ def get_model_trail(self): return None def get_events(self, **kwargs): - """ Get events from the database. + """Get events from the database. 
:param kwargs: query to filter events :type kwargs: dict @@ -299,51 +427,83 @@ def get_events(self, **kwargs): :rtype: ObjectId """ # check if kwargs is empty + + result = None + count = None + projection = {"_id": False} + if not kwargs: - return self.control.status.find() + result = self.control.status.find({}, projection).sort( + "timestamp", pymongo.DESCENDING + ) + count = self.control.status.count_documents({}) else: - result = self.control.status.find(kwargs) - return result + limit = kwargs.pop("limit", None) + skip = kwargs.pop("skip", None) + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = ( + self.control.status.find(kwargs, projection) + .sort("timestamp", pymongo.DESCENDING) + .limit(limit) + .skip(skip) + ) + else: + result = self.control.status.find(kwargs, projection).sort( + "timestamp", pymongo.DESCENDING + ) + + count = self.control.status.count_documents(kwargs) + + return { + "result": result, + "count": count, + } def get_storage_backend(self): - """ Get the storage backend. + """Get the storage backend. :return: The storage backend. :rtype: ObjectID """ try: ret = self.storage.find( - {'status': 'enabled'}, projection={'_id': False}) + {"status": "enabled"}, projection={"_id": False} + ) return ret[0] except (KeyError, IndexError): return None def set_storage_backend(self, config): - """ Set the storage backend. + """Set the storage backend. :param config: The storage backend configuration. :type config: dict :return: """ config = copy.deepcopy(config) - config['updated_at'] = str(datetime.now()) - config['status'] = 'enabled' + config["updated_at"] = str(datetime.now()) + config["status"] = "enabled" self.storage.update_one( - {'storage_type': config['storage_type']}, {'$set': config}, True) + {"storage_type": config["storage_type"]}, {"$set": config}, True + ) def set_reducer(self, reducer_data): - """ Set the reducer in the statestore. + """Set the reducer in the statestore. 
:param reducer_data: dictionary of reducer config. :type reducer_data: dict :return: """ - reducer_data['updated_at'] = str(datetime.now()) - self.reducer.update_one({'name': reducer_data['name']}, { - '$set': reducer_data}, True) + reducer_data["updated_at"] = str(datetime.now()) + self.reducer.update_one( + {"name": reducer_data["name"]}, {"$set": reducer_data}, True + ) def get_reducer(self): - """ Get reducer.config. + """Get reducer.config. return: reducer config. rtype: ObjectId @@ -355,67 +515,99 @@ def get_reducer(self): return None def get_combiner(self, name): - """ Get combiner by name. + """Get combiner by name. + :param name: name of combiner to get. + :type name: str :return: The combiner. :rtype: ObjectId """ try: - ret = self.combiners.find_one({'name': name}) + ret = self.combiners.find_one({"name": name}) return ret except Exception: return None - def get_combiners(self): - """ Get all combiners. - - :return: list of combiners. - :rtype: list + def get_combiners(self, limit=None, skip=None, sort_key="updated_at", sort_order=pymongo.DESCENDING, projection={}): + """Get all combiners. + + :param limit: The maximum number of combiners to return. + :type limit: int + :param skip: The number of combiners to skip. + :type skip: int + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :param projection: The projection. + :type projection: dict + :return: Dictionary of combiners in result and count. 
+ :rtype: dict """ + + result = None + count = None + try: - ret = self.combiners.find() - return list(ret) + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.combiners.find({}, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.combiners.find({}, projection).sort(sort_key, sort_order) + + count = self.combiners.count_documents({}) + except Exception: return None + return { + "result": result, + "count": count, + } + def set_combiner(self, combiner_data): - """ Set combiner in statestore. + """Set combiner in statestore. :param combiner_data: dictionary of combiner config :type combiner_data: dict :return: """ - combiner_data['updated_at'] = str(datetime.now()) - self.combiners.update_one({'name': combiner_data['name']}, { - '$set': combiner_data}, True) + combiner_data["updated_at"] = str(datetime.now()) + self.combiners.update_one( + {"name": combiner_data["name"]}, {"$set": combiner_data}, True + ) def delete_combiner(self, combiner): - """ Delete a combiner from statestore. + """Delete a combiner from statestore. :param combiner: name of combiner to delete. :type combiner: str :return: """ try: - self.combiners.delete_one({'name': combiner}) + self.combiners.delete_one({"name": combiner}) except Exception: - print("WARNING, failed to delete combiner: {}".format( - combiner), flush=True) + print( + "WARNING, failed to delete combiner: {}".format(combiner), + flush=True, + ) def set_client(self, client_data): - """ Set client in statestore. + """Set client in statestore. :param client_data: dictionary of client config. 
:type client_data: dict :return: """ - client_data['updated_at'] = str(datetime.now()) - self.clients.update_one({'name': client_data['name']}, { - '$set': client_data}, True) + client_data["updated_at"] = str(datetime.now()) + self.clients.update_one( + {"name": client_data["name"]}, {"$set": client_data}, True + ) def get_client(self, name): - """ Get client by name. + """Get client by name. :param name: name of client to get. :type name: str @@ -423,7 +615,7 @@ def get_client(self, name): :rtype: ObjectId """ try: - ret = self.clients.find({'key': name}) + ret = self.clients.find({"key": name}) if list(ret) == []: return None else: @@ -431,20 +623,77 @@ def get_client(self, name): except Exception: return None - def list_clients(self): + def list_clients(self, limit=None, skip=None, status=None, sort_key="last_seen", sort_order=pymongo.DESCENDING): """List all clients registered on the network. - :return: list of clients. + :param limit: The maximum number of clients to return. + :type limit: int + :param skip: The number of clients to skip. + :type skip: int + :param status: online | offline + :type status: str + :param sort_key: The key to sort by. + """ + + result = None + count = None + + try: + find = {} if status is None else {"status": status} + projection = {"_id": False, "updated_at": False} + + if limit is not None and skip is not None: + limit = int(limit) + skip = int(skip) + result = self.clients.find(find, projection).limit(limit).skip(skip).sort(sort_key, sort_order) + else: + result = self.clients.find(find, projection).sort(sort_key, sort_order) + + count = self.clients.count_documents(find) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return { + "result": result, + "count": count, + } + + def list_combiners_data(self, combiners, sort_key="count", sort_order=pymongo.DESCENDING): + """List all combiner data. + + :param combiners: list of combiners to get data for. 
+ :type combiners: list + :param sort_key: The key to sort by. + :type sort_key: str + :param sort_order: The sort order. + :type sort_order: pymongo.ASCENDING or pymongo.DESCENDING + :return: list of combiner data. :rtype: list(ObjectId) """ + + result = None + try: - ret = self.clients.find() - return list(ret) - except Exception: - return None + + pipeline = [ + {"$match": {"combiner": {"$in": combiners}, "status": "online"}}, + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] if combiners is not None else [ + {"$group": {"_id": "$combiner", "count": {"$sum": 1}}}, + {"$sort": {sort_key: sort_order, "_id": pymongo.ASCENDING}} + ] + + result = self.clients.aggregate(pipeline) + + except Exception as e: + print("ERROR: {}".format(e), flush=True) + + return result def update_client_status(self, client_data, status, role): - """ Set or update client status. + """Set or update client status. :param client_data: dictionary of client config. 
:type client_data: dict @@ -454,10 +703,7 @@ def update_client_status(self, client_data, status, role): :type role: str :return: """ - self.clients.update_one({"name": client_data['name']}, - {"$set": - { - "status": status, - "role": role - } - }) + self.clients.update_one( + {"name": client_data["name"]}, + {"$set": {"status": status, "role": role}}, + ) From b6d782e05dbfc774c5dbe1c3722029f3fe656ae7 Mon Sep 17 00:00:00 2001 From: Niklas Date: Thu, 2 Nov 2023 16:28:11 +0100 Subject: [PATCH 6/8] Check if clients object has last_seen property --- fedn/fedn/network/api/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 3a20e3502..f4a2bd2d6 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -81,7 +81,7 @@ def get_clients(self, limit=None, skip=None, status=False): "combiner_preferred": element["combiner_preferred"], "ip": element["ip"], "status": element["status"], - "last_seen": element["last_seen"], + "last_seen": element["last_seen"] if element.has_key("last_seen") else "", } arr.append(obj) From 1166f317e1c405e8329e02aff54bc870e071b38e Mon Sep 17 00:00:00 2001 From: Niklas Date: Thu, 2 Nov 2023 16:31:44 +0100 Subject: [PATCH 7/8] has_key is deprecated --- fedn/fedn/network/api/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index f4a2bd2d6..61095e6ec 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -81,7 +81,7 @@ def get_clients(self, limit=None, skip=None, status=False): "combiner_preferred": element["combiner_preferred"], "ip": element["ip"], "status": element["status"], - "last_seen": element["last_seen"] if element.has_key("last_seen") else "", + "last_seen": element["last_seen"] if "last_seen" in element else "", } arr.append(obj) From 
ae06b846efe714810cb932a852021dab282c991a Mon Sep 17 00:00:00 2001 From: Andreas Hellander Date: Thu, 9 Nov 2023 17:41:14 +0100 Subject: [PATCH 8/8] Feature/SK-521 | Global model not created if the combiner terminates based on timeout (#478) --- .ci/tests/examples/wait_for.py | 6 +- examples/mnist-keras/bin/build.sh | 2 +- fedn/fedn/common/storage/s3/miniorepo.py | 6 +- fedn/fedn/common/tracer/mongotracer.py | 45 +++- fedn/fedn/network/api/interface.py | 5 +- fedn/fedn/network/controller/control.py | 228 ++++++++++---------- fedn/fedn/network/controller/controlbase.py | 107 ++++++--- 7 files changed, 240 insertions(+), 159 deletions(-) diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index dc3345da0..ccd76859d 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -18,7 +18,7 @@ def _retry(try_func, **func_args): for _ in range(RETRIES): is_success = try_func(**func_args) if is_success: - _eprint('Sucess.') + _eprint('Success.') return True _eprint(f'Sleeping for {SLEEP}.') sleep(SLEEP) @@ -30,7 +30,7 @@ def _test_rounds(n_rounds): client = pymongo.MongoClient( "mongodb://fedn_admin:password@localhost:6534") collection = client['fedn-network']['control']['rounds'] - query = {'reducer.status': 'Success'} + query = {'status': 'Finished'} n = collection.count_documents(query) client.close() _eprint(f'Succeded rounds: {n}.') @@ -60,7 +60,7 @@ def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092 return count == n_nodes except Exception as e: - _eprint(f'Reques exception econuntered: {e}.') + _eprint(f'Request exception enconuntered: {e}.') return False diff --git a/examples/mnist-keras/bin/build.sh b/examples/mnist-keras/bin/build.sh index 18cdb5128..44eda61df 100755 --- a/examples/mnist-keras/bin/build.sh +++ b/examples/mnist-keras/bin/build.sh @@ -5,4 +5,4 @@ set -e client/entrypoint init_seed # Make compute package -tar -czvf package.tgz client \ No newline at end of file +tar -czvf 
package.tgz client diff --git a/fedn/fedn/common/storage/s3/miniorepo.py b/fedn/fedn/common/storage/s3/miniorepo.py index 9341704e6..154cea7e9 100644 --- a/fedn/fedn/common/storage/s3/miniorepo.py +++ b/fedn/fedn/common/storage/s3/miniorepo.py @@ -62,11 +62,13 @@ def __init__(self, config): self.create_bucket(self.bucket) def create_bucket(self, bucket_name): - """ + """ Create a new bucket. If bucket exists, do nothing. - :param bucket_name: + :param bucket_name: The name of the bucket + :type bucket_name: str """ found = self.client.bucket_exists(bucket_name) + if not found: try: self.client.make_bucket(bucket_name) diff --git a/fedn/fedn/common/tracer/mongotracer.py b/fedn/fedn/common/tracer/mongotracer.py index 0a3e28cdc..aa5c0810b 100644 --- a/fedn/fedn/common/tracer/mongotracer.py +++ b/fedn/fedn/common/tracer/mongotracer.py @@ -52,18 +52,26 @@ def drop_status(self): if self.status: self.status.drop() - def new_session(self, id=None): - """ Create a new session. """ + def create_session(self, id=None): + """ Create a new session. + + :param id: The ID of the created session. + :type id: uuid, str + + """ if not id: id = uuid.uuid4() data = {'session_id': str(id)} self.sessions.insert_one(data) - def new_round(self, id): - """ Create a new session. """ + def create_round(self, round_data): + """ Create a new round. - data = {'round_id': str(id)} - self.rounds.insert_one(data) + :param round_data: Dictionary with round data. 
+ :type round_data: dict + """ + # TODO: Add check if round_id already exists + self.rounds.insert_one(round_data) def set_session_config(self, id, config): self.sessions.update_one({'session_id': str(id)}, { @@ -72,18 +80,35 @@ def set_session_config(self, id, config): def set_round_combiner_data(self, data): """ - :param round_meta: + :param data: The combiner data + :type data: dict """ self.rounds.update_one({'round_id': str(data['round_id'])}, { '$push': {'combiners': data}}, True) - def set_round_data(self, round_data): + def set_round_config(self, round_id, round_config): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_config': round_config}}, True) + + def set_round_status(self, round_id, round_status): + """ + + :param round_meta: + """ + self.rounds.update_one({'round_id': round_id}, { + '$set': {'status': round_status}}, True) + + def set_round_data(self, round_id, round_data): """ :param round_meta: """ - self.rounds.update_one({'round_id': str(round_data['round_id'])}, { - '$push': {'reducer': round_data}}, True) + self.rounds.update_one({'round_id': round_id}, { + '$set': {'round_data': round_data}}, True) def update_client_status(self, client_name, status): """ Update client status in statestore. 
diff --git a/fedn/fedn/network/api/interface.py b/fedn/fedn/network/api/interface.py index 61095e6ec..0821ed176 100644 --- a/fedn/fedn/network/api/interface.py +++ b/fedn/fedn/network/api/interface.py @@ -707,9 +707,8 @@ def get_round(self, round_id): if round_object is None: return jsonify({"success": False, "message": "Round not found."}) payload = { - "round_id": round_object["round_id"], - "reducer": round_object["reducer"], - "combiners": round_object["combiners"], + 'round_id': round_object['round_id'], + 'combiners': round_object['combiners'], } return jsonify(payload) diff --git a/fedn/fedn/network/controller/control.py b/fedn/fedn/network/controller/control.py index a8e32333d..615edb3b5 100644 --- a/fedn/fedn/network/controller/control.py +++ b/fedn/fedn/network/controller/control.py @@ -3,6 +3,9 @@ import time import uuid +from tenacity import (retry, retry_if_exception_type, stop_after_delay, + wait_random) + from fedn.network.combiner.interfaces import CombinerUnavailableError from fedn.network.controller.controlbase import ControlBase from fedn.network.state import ReducerState @@ -48,6 +51,20 @@ def __init__(self, message): super().__init__(self.message) +class CombinersNotDoneException(Exception): + """ Exception class for when model is None """ + + def __init__(self, message): + """ Constructor method. + + :param message: The exception message. + :type message: str + + """ + self.message = message + super().__init__(self.message) + + class Control(ControlBase): """Controller, implementing the overall global training, validation and inference logic. 
@@ -83,12 +100,10 @@ def session(self, config): return self._state = ReducerState.instructing - - # Must be called to set info in the db config["committed_at"] = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S" ) - self.new_session(config) + self.create_session(config) if not self.statestore.get_latest_model(): print( @@ -106,14 +121,13 @@ def session(self, config): # Execute the rounds in this session for round in range(1, int(config["rounds"] + 1)): # Increment the round number - if last_round: current_round = last_round + round else: current_round = round try: - _, round_data = self.round(config, current_round) + _, round_data = self.round(config, str(current_round)) except TypeError as e: print( "Could not unpack data from round: {0}".format(e), @@ -127,30 +141,27 @@ def session(self, config): flush=True, ) - self.tracer.set_round_data(round_data) - # TODO: Report completion of session self._state = ReducerState.idle def round(self, session_config, round_id): - """Execute a single global round. + """ Execute one global round. + + : param session_config: The session config. + : type session_config: dict + : param round_id: The round id. + : type round_id: str - :param session_config: The session config. - :type session_config: dict - :param round_id: The round id. - :type round_id: str(int) """ - round_data = {"round_id": round_id} + self.create_round({'round_id': round_id, 'status': "Pending"}) if len(self.network.get_combiners()) < 1: - print("REDUCER: No combiners connected!", flush=True) - round_data["status"] = "Failed" - return None, round_data + print("CONTROLLER: Round cannot start, no combiners connected!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 1. Assemble round config for this global round, - # and check which combiners are able to participate - # in the round. 
+ # Assemble round config for this global round round_config = copy.deepcopy(session_config) round_config["rounds"] = 1 round_config["round_id"] = round_id @@ -158,94 +169,85 @@ def round(self, session_config, round_id): round_config["model_id"] = self.statestore.get_latest_model() round_config["helper_type"] = self.statestore.get_helper() - combiners = self.get_participating_combiners(round_config) - round_start = self.evaluate_round_start_policy(combiners) + self.set_round_config(round_id, round_config) + + # Get combiners that are able to participate in round, given round_config + participating_combiners = self.get_participating_combiners(round_config) + + # Check if the policy to start the round is met + round_start = self.evaluate_round_start_policy(participating_combiners) if round_start: - print( - "CONTROL: round start policy met, participating combiners {}".format( - combiners - ), - flush=True, - ) + print("CONTROL: round start policy met, {} participating combiners.".format( + len(participating_combiners)), flush=True) else: - print( - "CONTROL: Round start policy not met, skipping round!", - flush=True, - ) - round_data["status"] = "Failed" - return None + print("CONTROL: Round start policy not met, skipping round!", flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) + + # Ask participating combiners to coordinate model updates + _ = self.request_model_updates(participating_combiners) + # TODO: Check response + + # Wait until participating combiners have produced an updated global model, + # or round times out. + def do_if_round_times_out(result): + print("CONTROL: Round timed out!", flush=True) + + @retry(wait=wait_random(min=1.0, max=2.0), + stop=stop_after_delay(session_config['round_timeout']), + retry_error_callback=do_if_round_times_out, + retry=retry_if_exception_type(CombinersNotDoneException)) + def combiners_done(): - round_data["round_config"] = round_config - - # 2. 
Ask participating combiners to coordinate model updates - _ = self.request_model_updates(combiners) - - # Wait until participating combiners have produced an updated global model. - wait = 0.0 - # dict to store combiners that have successfully produced an updated model - updated = {} - # wait until all combiners have produced an updated model or until round timeout - print( - "CONTROL: Fetching round config (ID: {round_id}) from statestore:".format( - round_id=round_id - ), - flush=True, - ) - while len(updated) < len(combiners): round = self.statestore.get_round(round_id) - if round: - print("CONTROL: Round found!", flush=True) - # For each combiner in the round, check if it has produced an updated model (status == 'Success') - for combiner in round["combiners"]: - print(combiner, flush=True) - if combiner["status"] == "Success": - if combiner["name"] not in updated.keys(): - # Add combiner to updated dict - updated[combiner["name"]] = combiner["model_id"] - # Print combiner status - print( - "CONTROL: Combiner {name} status: {status}".format( - name=combiner["name"], status=combiner["status"] - ), - flush=True, - ) - else: - # Print every 10 seconds based on value of wait - if wait % 10 == 0: - print( - "CONTROL: Waiting for round to complete...", flush=True - ) - if wait >= session_config["round_timeout"]: - print("CONTROL: Round timeout! 
Exiting round...", flush=True) - break - # Update wait time used for timeout - time.sleep(1.0) - wait += 1.0 - - round_valid = self.evaluate_round_validity_policy(updated) + if 'combiners' not in round: + # TODO: use logger + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("Combiners have not yet reported.") + + if len(round['combiners']) < len(participating_combiners): + print("CONTROL: Waiting for combiners to update model...", flush=True) + raise CombinersNotDoneException("All combiners have not yet reported.") + + return True + + combiners_done() + + # Due to the distributed nature of the computation, there might be a + # delay before combiners have reported the round data to the db, + # so we need some robustness here. + @retry(wait=wait_random(min=0.1, max=1.0), + retry=retry_if_exception_type(KeyError)) + def check_combiners_done_reporting(): + round = self.statestore.get_round(round_id) + combiners = round['combiners'] + return combiners + + _ = check_combiners_done_reporting() + + round = self.statestore.get_round(round_id) + round_valid = self.evaluate_round_validity_policy(round) if not round_valid: print("REDUCER CONTROL: Round invalid!", flush=True) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - print("CONTROL: Reducing models from combiners...", flush=True) - # 3. 
Reduce combiner models into a global model + print("CONTROL: Reducing combiner level models...", flush=True) + # Reduce combiner models into a new global model + round_data = {} try: - model, data = self.reduce(updated) - round_data["reduce"] = data + round = self.statestore.get_round(round_id) + model, data = self.reduce(round['combiners']) + round_data['reduce'] = data print("CONTROL: Done reducing models from combiners!", flush=True) except Exception as e: - print( - "CONTROL: Failed to reduce models from combiners: {}".format( - e - ), - flush=True, - ) - round_data["status"] = "Failed" - return None, round_data + print("CONTROL: Failed to reduce models from combiners: {}".format( + e), flush=True) + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - # 6. Commit the global model to model trail + # Commit the new global model to the model trail if model is not None: print( "CONTROL: Committing global model to model trail...", @@ -271,10 +273,10 @@ def round(self, session_config, round_id): ), flush=True, ) - round_data["status"] = "Failed" - return None, round_data + self.set_round_status(round_id, 'Failed') + return None, self.statestore.get_round(round_id) - round_data["status"] = "Success" + self.set_round_status(round_id, 'Success') # 4. 
Trigger participating combiner nodes to execute a validation round for the current model validate = session_config["validate"] @@ -285,9 +287,8 @@ def round(self, session_config, round_id): combiner_config["task"] = "validation" combiner_config["helper_type"] = self.statestore.get_helper() - validating_combiners = self._select_participating_combiners( - combiner_config - ) + validating_combiners = self.get_participating_combiners( + combiner_config) for combiner, combiner_config in validating_combiners: try: @@ -302,13 +303,15 @@ def round(self, session_config, round_id): self._handle_unavailable_combiner(combiner) pass - return model_id, round_data + self.set_round_data(round_id, round_data) + self.set_round_status(round_id, 'Finished') + return model_id, self.statestore.get_round(round_id) def reduce(self, combiners): """Combine updated models from Combiner nodes into one global model. - :param combiners: dict of combiner names (key) and model IDs (value) to reduce - :type combiners: dict + : param combiners: dict of combiner names(key) and model IDs(value) to reduce + : type combiners: dict """ meta = {} @@ -323,7 +326,9 @@ def reduce(self, combiners): print("REDUCER: No combiners to reduce!", flush=True) return model, meta - for name, model_id in combiners.items(): + for combiner in combiners: + name = combiner['name'] + model_id = combiner['model_id'] # TODO: Handle inactive RPC error in get_model and raise specific error print( "REDUCER: Fetching model ({model_id}) from combiner {name}".format( @@ -333,9 +338,9 @@ def reduce(self, combiners): ) try: tic = time.time() - combiner = self.get_combiner(name) - data = combiner.get_model(model_id) - meta["time_fetch_model"] += time.time() - tic + combiner_interface = self.get_combiner(name) + data = combiner_interface.get_model(model_id) + meta['time_fetch_model'] += (time.time() - tic) except Exception as e: print( "REDUCER: Failed to fetch model from combiner {}: {}".format( @@ -367,7 +372,7 @@ def reduce(self, 
combiners): def infer_instruct(self, config): """Main entrypoint for executing the inference compute plan. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Check/set instucting state @@ -395,7 +400,7 @@ def infer_instruct(self, config): def inference_round(self, config): """Execute an inference round. - :param config: configuration for the inference round + : param config: configuration for the inference round """ # Init meta @@ -413,7 +418,8 @@ def inference_round(self, config): combiner_config["helper_type"] = self.statestore.get_framework() # Select combiners - validating_combiners = self._select_round_combiners(combiner_config) + validating_combiners = self.get_participating_combiners( + combiner_config) # Test round start policy round_start = self.check_round_start_policy(validating_combiners) diff --git a/fedn/fedn/network/controller/controlbase.py b/fedn/fedn/network/controller/controlbase.py index 077620c14..fab6a2027 100644 --- a/fedn/fedn/network/controller/controlbase.py +++ b/fedn/fedn/network/controller/controlbase.py @@ -196,8 +196,8 @@ def get_compute_package(self, compute_package=""): else: return None - def new_session(self, config): - """Initialize a new session in backend db.""" + def create_session(self, config): + """ Initialize a new session in backend db. """ if "session_id" not in config.keys(): session_id = uuid.uuid4() @@ -205,11 +205,50 @@ def new_session(self, config): else: session_id = config["session_id"] - self.tracer.new_session(id=session_id) + self.tracer.create_session(id=session_id) self.tracer.set_session_config(session_id, config) + def create_round(self, round_data): + """Initialize a new round in backend db. """ + + self.tracer.create_round(round_data) + + def set_round_data(self, round_id, round_data): + """ Set round data. 
+ + :param round_id: The round unique identifier + :type round_id: str + :param round_data: The round data + :type round_data: dict + """ + self.tracer.set_round_data(round_id, round_data) + + def set_round_status(self, round_id, status): + """ Set the round status. + + :param round_id: The round unique identifier + :type round_id: str + :param status: The status + :type status: str + """ + self.tracer.set_round_status(round_id, status) + + def set_round_config(self, round_id, round_config): + """ Update round in backend db. + + :param round_id: The round unique identifier + :type round_id: str + :param round_config: The round configuration + :type round_config: dict + """ + self.tracer.set_round_config(round_id, round_config) + def request_model_updates(self, combiners): - """Call Combiner server RPC to get a model update.""" + """Ask Combiner server to produce a model update. + + :param combiners: A list of combiners + :type combiners: tuple (combiner, combiner_round_config) + """ cl = [] for combiner, combiner_round_config in combiners: response = combiner.submit(combiner_round_config) @@ -217,7 +256,15 @@ def request_model_updates(self, combiners): return cl def commit(self, model_id, model=None, session_id=None): - """Commit a model to the global model trail. The model commited becomes the lastest consensus model.""" + """Commit a model to the global model trail. The model committed becomes the latest consensus model. + + :param model_id: Unique identifier for the model to commit. + :type model_id: str (uuid) + :param model: The model object to commit + :type model: BytesIO + :param session_id: Unique identifier for the session + :type session_id: str + """ helper = self.get_helper() if model is not None: @@ -289,45 +336,47 @@ def evaluate_round_participation_policy( return False def evaluate_round_start_policy(self, combiners): - """Check if the policy to start a round is met.""" + """Check if the policy to start a round is met. 
+ + :param combiners: A list of combiners + :type combiners: list + :return: True if the round policy is met, otherwise False + :rtype: bool + """ if len(combiners) > 0: return True else: return False - def evaluate_round_validity_policy(self, combiners): - """Check if the round should be seen as valid. + def evaluate_round_validity_policy(self, round): + """ Check if the round is valid. At the end of the round, before committing a model to the global model trail, we check if the round validity policy has been met. This can involve e.g. asserting that a certain number of combiners have reported in an updated model, or that criteria on model performance have been met. - """ - if combiners.keys() == []: - return False - else: - return True - def _select_participating_combiners(self, compute_plan): - participating_combiners = [] - for combiner in self.network.get_combiners(): + :param round: The round object + :type round: dict + :return: True if the policy is met, otherwise False + :rtype: bool + """ + model_ids = [] + for combiner in round['combiners']: try: - combiner_state = combiner.report() - except CombinerUnavailableError: - self._handle_unavailable_combiner(combiner) - combiner_state = None + model_ids.append(combiner['model_id']) + except KeyError: + pass - if combiner_state: - is_participating = self.evaluate_round_participation_policy( - compute_plan, combiner_state - ) - if is_participating: - participating_combiners.append((combiner, compute_plan)) - return participating_combiners + if len(model_ids) == 0: + return False + + return True def state(self): - """ + """ Get the current state of the controller - :return: + :return: The state + :rtype: str """ return self._state