Skip to content

Commit

Permalink
Make it possible to configure the aggregator per session
Browse files Browse the repository at this point in the history
  • Loading branch information
Andreas Hellander committed Jan 2, 2024
1 parent 353dcf4 commit e01eb8e
Show file tree
Hide file tree
Showing 14 changed files with 844 additions and 529 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ services:
- ${HOST_REPO_DIR:-.}/fedn:/app/fedn
entrypoint: [ "sh", "-c" ]
command:
- "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run combiner -a fedopt --init config/settings-combiner.yaml"
- "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run combiner --init config/settings-combiner.yaml"
ports:
- 12080:12080

Expand Down
29 changes: 19 additions & 10 deletions examples/notebooks/API_Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,24 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"id": "f0380d35",
"metadata": {},
"outputs": [],
"source": [
"session_config = {\n",
" \"helper\": \"pytorchhelper\",\n",
" \"session_id\": str(uuid.uuid4()) \n",
" \"session_id\": \"session_fedavg\",\n",
" \"aggregator\": \"fedavg\"\n",
" }\n",
"\n",
"result = client.start_session(**session_config)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d92923c7",
"execution_count": 6,
"id": "acf65237",
"metadata": {
"scrolled": true
},
Expand All @@ -114,15 +115,15 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"id": "f4968b3a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OrderedDict([('9979d294-bf54-4d8d-b7d6-be499859b897', [0.367333322763443, 0.36899998784065247]), ('d4a611f6-bedb-4cc0-81c8-c6a3df909285', [0.4584999978542328, 0.4438333213329315]), ('5c74bdf8-f29d-4adc-bdf5-9a38ac1e2543', [0.5538333058357239, 0.5398333072662354]), ('5a7f64ca-6320-49a7-81f7-e1e09b5a75af', [0.715666651725769, 0.7070000171661377]), ('97c7f17c-6e59-4fe4-9ae8-b1bd1eb314d6', [0.7738333344459534, 0.765999972820282])])\n"
"OrderedDict([('9069d8eb-d009-4d27-806f-c536791d931a', [0.367333322763443, 0.36899998784065247]), ('3bc70b11-5634-492f-8fa8-e26f161a0a25', [0.4584999978542328, 0.4438333213329315]), ('235b9e7a-1fa8-4c98-ba13-b4bc2249eae4', [0.5398333072662354, 0.5538333058357239]), ('fa88366a-3f74-4bac-8ec6-870e4e947617', [0.715666651725769, 0.7070000171661377]), ('45706812-a1e1-40cc-9c39-b493c82e8ddb', [0.7738333344459534, 0.765999972820282])])\n"
]
}
],
Expand All @@ -141,7 +142,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 8,
"id": "900eb0a7",
"metadata": {},
"outputs": [],
Expand All @@ -153,17 +154,17 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 9,
"id": "d064aaf9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x114ea7e20>]"
"[<matplotlib.lines.Line2D at 0x123893c10>]"
]
},
"execution_count": 14,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -181,6 +182,14 @@
"source": [
"plt.plot(mean_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "597d47f1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
5 changes: 2 additions & 3 deletions fedn/cli/run_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,8 @@ def dashboard_cmd(ctx, host, port, secret_key, local_package, name, init):
@click.option('-c', '--max_clients', required=False, default=30, help='The maximal number of client connections allowed.')
@click.option('-in', '--init', required=False, default=None,
help='Path to configuration file to (re)init combiner.')
@click.option('-a', '--aggregator', required=False, default='fedavg', help='Filename of the aggregator module to use.')
@click.pass_context
def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init, aggregator):
def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn, secure, verify, max_clients, init):
"""
:param ctx:
Expand All @@ -269,7 +268,7 @@ def combiner_cmd(ctx, discoverhost, discoverport, token, name, host, port, fqdn,
"""
config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'host': host,
'port': port, 'fqdn': fqdn, 'name': name, 'secure': secure, 'verify': verify, 'max_clients': max_clients,
'init': init, 'aggregator': aggregator}
'init': init}

if config['init']:
apply_config(config)
Expand Down
3 changes: 2 additions & 1 deletion fedn/fedn/network/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_round(self, round_id):
response = requests.get(self._get_url(f'get_round?round_id={round_id}'), verify=self.verify)
return response.json()

def start_session(self, session_id=None, round_timeout=180, rounds=5, round_buffer_size=-1, delete_models=True,
def start_session(self, session_id=None, aggregator='fedavg', round_timeout=180, rounds=5, round_buffer_size=-1, delete_models=True,
validate=True, helper='kerashelper', min_clients=1, requested_clients=8):
""" Start a new session.
Expand All @@ -136,6 +136,7 @@ def start_session(self, session_id=None, round_timeout=180, rounds=5, round_buff
:rtype: dict
"""
response = requests.post(self._get_url('start_session'), json={
'aggregator': aggregator,
'session_id': session_id,
'round_timeout': round_timeout,
'rounds': rounds,
Expand Down
2 changes: 2 additions & 0 deletions fedn/fedn/network/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ def list_combiners_data(self, combiners):
def start_session(
self,
session_id,
aggregator='fedavg',
rounds=5,
round_timeout=180,
round_buffer_size=-1,
Expand Down Expand Up @@ -836,6 +837,7 @@ def start_session(
session_config = {
"session_id": session_id if session_id else str(uuid.uuid4()),
"round_timeout": round_timeout,
"aggregator": aggregator,
"buffer_size": round_buffer_size,
"model_id": model_id,
"rounds": rounds,
Expand Down
2 changes: 1 addition & 1 deletion fedn/fedn/network/combiner/aggregators/fedopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, storage, server, modelservice, control):

super().__init__(storage, server, modelservice, control)

self.name = "fedavg"
self.name = "fedopt"
# Server side hyperparameters
self.eta = 1
self.beta1 = 0.9
Expand Down
26 changes: 25 additions & 1 deletion fedn/fedn/network/combiner/combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self, config):
self.server = Server(self, self.modelservice, grpc_config)

# Set up round controller
self.control = RoundController(config['aggregator'], self.repository, self, self.modelservice)
self.control = RoundController(self.repository, self, self.modelservice)

# Start thread for round controller
threading.Thread(target=self.control.run, daemon=True).start()
Expand Down Expand Up @@ -428,6 +428,30 @@ def Start(self, control: fedn.ControlRequest, context):

return response

def SetAggregator(self, control: fedn.ControlRequest, context):
""" Set the active aggregator.
:param control: the control request
:type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest`
:param context: the context (unused)
:type context: :class:`grpc._server._Context`
:return: the control response
:rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse`
"""
logger.debug("grpc.Combiner.SetAggregator: Called")
for parameter in control.parameter:
aggregator = parameter.value

status = self.control.set_aggregator(aggregator)

response = fedn.ControlResponse()
if status:
response.message = 'Success'
else:
response.message = 'Failed'

return response

def FlushAggregationQueue(self, control: fedn.ControlRequest, context):
""" Flush the queue.
Expand Down
24 changes: 24 additions & 0 deletions fedn/fedn/network/combiner/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,30 @@ def flush_model_update_queue(self):
else:
raise

def set_aggregator(self, aggregator):
""" Set the active aggregator module.
:param aggregator: The name of the aggregator module.
:type config: str
"""

channel = Channel(self.address, self.port,
self.certificate).get_channel()
control = rpc.ControlStub(channel)

request = fedn.ControlRequest()
p = request.parameter.add()
p.key = "aggregator"
p.value = aggregator

try:
control.SetAggregator(request)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNAVAILABLE:
raise CombinerUnavailableError
else:
raise

def submit(self, config):
""" Submit a compute plan to the combiner.
Expand Down
6 changes: 4 additions & 2 deletions fedn/fedn/network/combiner/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ class RoundController:
:type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
"""

def __init__(self, aggregator_name, storage, server, modelservice):
def __init__(self, storage, server, modelservice):
""" Initialize the RoundController."""

self.round_configs = queue.Queue()
self.storage = storage
self.server = server
self.modelservice = modelservice
self.aggregator = get_aggregator(aggregator_name, self.storage, self.server, self.modelservice, self)

def set_aggregator(self, aggregator):
self.aggregator = get_aggregator(aggregator, self.storage, self.server, self.modelservice, self)

def push_round_config(self, round_config):
"""Add a round_config (job description) to the inbox.
Expand Down
3 changes: 2 additions & 1 deletion fedn/fedn/network/controller/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def session(self, config):

# Clear potential stragglers/old model updates at combiners
for combiner in self.network.get_combiners():
combiner.flush_model_update_queue()
# combiner.flush_model_update_queue()
combiner.set_aggregator(config['aggregator'])

# Execute the rounds in this session
for round in range(1, int(config["rounds"] + 1)):
Expand Down
1 change: 1 addition & 0 deletions fedn/fedn/network/grpc/fedn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ service Control {
rpc Start(ControlRequest) returns (ControlResponse);
rpc Stop(ControlRequest) returns (ControlResponse);
rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse);
rpc SetAggregator(ControlRequest) returns (ControlResponse);
}

service Reducer {
Expand Down
Loading

0 comments on commit e01eb8e

Please sign in to comment.