Commit

fedadam working for pytorch
Andreas Hellander committed Jan 19, 2024
1 parent be5051b commit e8238b0
Showing 6 changed files with 189 additions and 57 deletions.
6 changes: 6 additions & 0 deletions docker-compose.yaml
@@ -119,6 +119,11 @@ services:
- "/venv/bin/pip install --no-cache-dir -e /app/fedn && /venv/bin/fedn run combiner --init config/settings-combiner.yaml"
ports:
- 12080:12080
depends_on:
- api-server
healthcheck:
test: [ "/bin/bash" ]
start_period: 10s
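# NOTE: "/bin/bash" with no arguments exits 0 immediately, so this check
# always reports healthy; it reads as a minimal placeholder healthcheck.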

# Client
client:
@@ -137,4 +142,5 @@ services:
deploy:
replicas: 0
depends_on:
- api-server
- combiner
105 changes: 69 additions & 36 deletions examples/notebooks/API_Example.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions fedn/fedn/network/api/client.py
@@ -1,3 +1,5 @@
import json

import requests

__all__ = ['APIClient']
44 changes: 33 additions & 11 deletions fedn/fedn/network/combiner/aggregators/fedopt.py
@@ -33,11 +33,14 @@ def __init__(self, storage, server, modelservice, control):
super().__init__(storage, server, modelservice, control)

self.name = "fedopt"
self.v = None
self.m = None

# Server-side hyperparameters
self.eta = 1
self.eta = 0.1
self.beta1 = 0.9
self.beta2 = 0.99
self.tau = 1e-3
self.tau = 1e-4

def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete_models=True):
"""Compute pseudo gradients usigng model updates in the queue.
@@ -65,9 +68,6 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete_models=True):
logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

# v = math.pow(self.tau, 2)
# m = 0.0

while not self.model_updates.empty():
try:
# Get next model from queue
@@ -101,14 +101,36 @@ def combine_models(self, helper=None, time_window=180, max_nr_models=100, delete_models=True):
"AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
self.model_updates.task_done()

# Server-side momentum
# m = helper.add(m, pseudo_gradient, self.beta1, (1.0-self.beta1))
# v = v + helper.power(pseudo_gradient, 2)
# model = model_old + self.eta*m/helper.sqrt(v+self.tau)

model = helper.add(model_old, pseudo_gradient, 1.0, self.eta)
model = self.serveropt_adam(helper, pseudo_gradient, model_old)

data['nr_aggregated_models'] = nr_aggregated_models

logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models))
return model, data

def serveropt_adam(self, helper, pseudo_gradient, model_old):
""" Server-side optimization, FedAdam.
:param helper: instance of helper class.
:type helper: Helper
:param pseudo_gradient: The pseudo gradient.
:type pseudo_gradient: As defined by helper.
:param model_old: The current global model.
:type model_old: As defined by helper.
:return: new model weights.
:rtype: as defined by helper.
"""

# Initialize the second moment to tau^2, element-wise (FedAdam initialization)
if not self.v:
self.v = helper.ones(pseudo_gradient, math.pow(self.tau, 2))

# First moment: m = beta1*m + (1 - beta1)*pseudo_gradient
if not self.m:
self.m = helper.multiply(pseudo_gradient, (1.0-self.beta1))
else:
self.m = helper.add(self.m, pseudo_gradient, self.beta1, (1.0-self.beta1))

# Second moment: v = beta2*v + (1 - beta2)*pseudo_gradient^2
p = helper.power(pseudo_gradient, 2)
self.v = helper.add(self.v, p, self.beta2, (1.0-self.beta2))

# Update direction: m / (sqrt(v) + tau)
sv = helper.add(helper.sqrt(self.v), helper.ones(self.v, self.tau))
t = helper.divide(self.m, sv)

# New global model: model_old + eta * m / (sqrt(v) + tau)
model = helper.add(model_old, t, 1.0, self.eta)
return model
3 changes: 1 addition & 2 deletions fedn/fedn/network/controller/control.py
@@ -167,11 +167,10 @@ def round(self, session_config, round_id):
round_config["rounds"] = 1
round_config["round_id"] = round_id
round_config["task"] = "training"
#round_config["helper_type"] = self.statestore.get_helper()

self.set_round_config(round_id, round_config)

# Get combiners that are able to participate in round, given round_config
# Get combiners that are able to participate in the round, given round_config
participating_combiners = self.get_participating_combiners(round_config)

# Check if the policy to start the round is met
86 changes: 78 additions & 8 deletions fedn/fedn/utils/plugins/pytorchhelper.py
@@ -34,7 +34,7 @@ def increment_average(self, model, model_next, num_examples, total_examples):
return w

def add(self, m1, m2, a=1.0, b=1.0):
""" Add weights.
""" m1*a + m2*b
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
@@ -50,28 +50,98 @@ def add(self, m1, m2, a=1.0, b=1.0):
return w

def subtract(self, m1, m2, a=1.0, b=1.0):
""" m1*a - m2*b.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:param m2: New model weights with keys from torch state_dict.
:type m2: OrderedDict
:return: m1*a-m2*b
:rtype: OrderedDict
"""
return self.add(m1, m2, a, -b)

def divide(self, m1, m2):
""" Subtract weights.
:param model: Current model weights with keys from torch state_dict.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:param m2: New model weights with keys from torch state_dict.
:type m2: OrderedDict
:return: m1/m2.
:rtype: OrderedDict
"""

res = OrderedDict()
for key, val in m1.items():
res[key] = np.divide(val, m2[key])

return res

def multiply(self, m1, m2):
""" Multiply m1 by m2.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:param m2: Scalar multiplier, applied element-wise to each tensor in m1.
:type m2: float
:return: m1.*m2
:rtype: OrderedDict
"""

res = OrderedDict()
for key, val in m1.items():
res[key] = np.multiply(np.array(val), m2)

return res

def sqrt(self, m1):
""" Sqrt of m1, element-wise.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:return: sqrt(m1)
:rtype: OrderedDict
"""
res = OrderedDict()
for key, val in m1.items():
res[key] = np.sqrt(np.array(val))

return res

def power(self, m1, a):
""" m1 raised to the power of m2.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:param m2: New model weights with keys from torch state_dict.
:type a: float
:return: m1.^m2
:rtype: OrderedDict
"""
res = OrderedDict()
for key, val in m1.items():
res[key] = np.power(val, a)

return res

def norm(self, m):
"""Compute the L1-norm of the tensor m. """
n = 0.0
for name, val in m.items():
n += np.linalg.norm(np.array(val), 1)

return n

def ones(self, m1, a):
""" Create tensors of constant value a, shaped like the entries of m1.
:param m1: Current model weights with keys from torch state_dict.
:type m1: OrderedDict
:param a: The constant fill value.
:type a: float
:return: a*ones with the same shapes as m1.
:rtype: OrderedDict
"""
res = OrderedDict()
for key, val in m1.items():
res[key] = np.ones(np.shape(val))*a

return res
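
For orientation, here is a small usage sketch composing these element-wise helpers into one Adam-style step by hand. The values and the helper class name (PytorchHelper) are assumptions for illustration:

from collections import OrderedDict

import numpy as np

from fedn.utils.plugins.pytorchhelper import PytorchHelper  # class name assumed

helper = PytorchHelper()
w = OrderedDict({"layer1": np.array([1.0, 4.0, 9.0])})  # current weights
g = OrderedDict({"layer1": np.array([0.5, 0.5, 0.5])})  # pseudo gradient

# step = g / (sqrt(w) + tau); w_new = w + eta*step, with tau=1e-4 and eta=0.1
denom = helper.add(helper.sqrt(w), helper.ones(w, 1e-4))
step = helper.divide(g, denom)
w_new = helper.add(w, step, 1.0, 0.1)
print(w_new["layer1"])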

def save(self, model, path=None):
""" Serialize weights to file. The serialized model must be a single binary object.
