Skip to content

Commit

Permalink
Switched to numel method to obtain the number of parameters. (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
HeyHao authored Nov 13, 2024
1 parent 48d4660 commit 4505e8f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
clients:
# Type
type: simple_he

# The total number of clients
total_clients: 10

# encrypt ratio
encrypt_ratio: 0.1
random_mask: False

# The number of clients selected in each round
per_round: 5

# Should the clients compute test accuracy locally?
do_test: false

server:
address: 127.0.0.1
port: 8001
random_seed: 1
simulate_wall_time: true

data:
# The training and testing dataset
datasource: CIFAR10

# Number of samples in each partition
partition_size: 1000

# IID or non-IID?
sampler: iid

trainer:
# The maximum number of training rounds
rounds: 25

# The maximum number of clients running concurrently
max_concurrency: 5

# The target accuracy
target_accuracy: 0.99

# The machine learning model
model_name: resnet_18

# Number of epoches for local training in each communication round
epochs: 5
batch_size: 32
optimizer: SGD
algorithm:
# Aggregation algorithm
type: fedavg

results:
# Write the following parameter(s) into a CSV
types: round, elapsed_time, accuracy, comm_overhead

parameters:
model:
num_classes: 10

optimizer:
lr: 0.01
momentum: 0.9
weight_decay: 0.0
3 changes: 2 additions & 1 deletion plato/servers/fedavg_he.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
A federated learning server using federated averaging to aggregate updates after homomorphic encryption.
"""

from functools import reduce
from plato.servers import fedavg
from plato.utils import homo_enc
Expand Down Expand Up @@ -36,7 +37,7 @@ def configure(self) -> None:

for key in extract_model.keys():
self.weight_shapes[key] = extract_model[key].size()
self.para_nums[key] = reduce(lambda a, b: a * b, self.weight_shapes[key])
self.para_nums[key] = extract_model[key].numel()

self.encrypted_model = homo_enc.encrypt_weights(
extract_model, True, self.context, []
Expand Down

0 comments on commit 4505e8f

Please sign in to comment.