-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_hete.py
95 lines (81 loc) · 2.72 KB
/
main_hete.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
# Restrict this process to the GPU with index 2. The variable is set before
# `import torch` below -- presumably so the restriction is already in place
# when CUDA is initialized.
# NOTE(review): the device index is hard-coded; confirm it suits the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import importlib
import torch
import numpy as np
import logging
import random
# Configure the root logger: INFO level, each record prefixed with the
# timestamp, source file name, line number and level name.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt='%a, %d %b %Y %H:%M:%S')
# Project-local helpers: dataset loaders, the client wrapper and the CLI parser.
from data.get_data import get_dataset_fed, get_dataset_fixed
from local.client import Client
from util.options import args_parse
# Datasets that come with their own client split: for these the number of
# clients is taken from the dataset loader rather than from the options.
fixed_dataset = ['shakespeare']

# Per-dataset class count, looked up by dataset name and handed to
# get_dataset_fed when the data is partitioned across clients.
class_num = {
    'cifar10': 10,
    'cifar100': 100,
    'mnist': 10,
    'fmnist': 10,
    'shakespeare': 80,
}
def get_debug():
    """Block until a debugpy client attaches to this process.

    Sets the process title to ``fedhete``, starts a debugpy listener on
    port 10001 and then waits for a debugger client to connect before
    returning.
    """
    # Imported inside the function so the two packages are only needed
    # when this helper is actually called.
    import debugpy
    import setproctitle

    process_title = "fedhete"
    debug_port = 10001

    setproctitle.setproctitle(process_title)
    debugpy.listen(debug_port)
    # Does not return until a client has connected.
    debugpy.wait_for_client()
def _seed_everything(seed):
    """Seed the torch (CPU and CUDA), ``random`` and NumPy generators with *seed*."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def read_options():
    """Parse the run options and assemble the server object.

    Steps, in order:

    1. Parse the options with ``args_parse`` and print them.
    2. Seed every RNG with 0 and load the dataset, so the data preparation
       sees the same random state on every run.
    3. Re-seed every RNG with ``options['seed']`` for everything that
       happens afterwards.
    4. Create one ``Client`` per client index.
    5. Import ``server.optimizer.<optimizer>`` and instantiate its
       ``Server`` class.

    For datasets listed in ``fixed_dataset`` the client count returned by
    the loader overwrites ``options['client_nums']``.

    Returns:
        The ``Server`` instance built from the global model, global
        trainer, main test dataset, client list and options.

    Raises:
        ValueError: if ``options['optimizer']`` is not a supported
            optimizer name.
    """
    options = args_parse()
    print(options)

    # Fixed seed while preparing the data so the dataset is the same
    # regardless of the user-supplied seed.
    _seed_everything(0)

    dataset_name = options['dataset']
    if dataset_name in fixed_dataset:
        # The loader dictates the number of clients for these datasets.
        (client_num, _, main_test_dataset,
         clients_trainset_list, clients_testset_list) = \
            get_dataset_fixed(dataset_name)
        options['client_nums'] = client_num
    else:
        (_, main_test_dataset,
         clients_trainset_list, clients_testset_list) = \
            get_dataset_fed(dataset_name, class_num[dataset_name],
                            options['client_nums'], options['part_method'],
                            options['alpha'])

    # User seed from here on: it drives parameter initialization and
    # client selection.
    _seed_everything(options['seed'])

    # The model and the trainers are left as None placeholders here.
    global_model = None
    global_trainer = None

    clients = []
    for i in range(options['client_nums']):
        trainer = None
        clients.append(
            Client(i, global_model, trainer, clients_trainset_list[i],
                   clients_testset_list[i],
                   batch_size=options['batch_size'])
        )

    # Validate with an explicit exception: an `assert` would be stripped
    # under `python -O` and let an arbitrary module name reach import_module.
    supported_optimizers = ('fedhete',)
    optimizer_name = options['optimizer']
    if optimizer_name not in supported_optimizers:
        raise ValueError(
            f"unsupported optimizer {optimizer_name!r}; "
            f"expected one of {supported_optimizers}"
        )

    # Resolve the server implementation dynamically from the optimizer name.
    optim_lib = importlib.import_module(f'server.optimizer.{optimizer_name}')
    optim_class = getattr(optim_lib, 'Server')
    global_server = optim_class(global_model, global_trainer,
                                main_test_dataset, clients, options)

    logging.debug('server starts to communicate with clients.')
    return global_server
def main():
    """Build the server from the parsed options and run it."""
    read_options().run()
if __name__ == "__main__":
    # Script entry point: wait for a debugger to attach, then start the run.
    # NOTE(review): get_debug() is called unconditionally, so every launch
    # blocks until a debugger client connects -- confirm this is intended.
    get_debug()
    main()