forked from FederatedAI/FATE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hetero_nn.py
254 lines (212 loc) · 11.1 KB
/
hetero_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pipeline.component.component_base import FateComponent
from pipeline.component.nn.models.sequantial import Sequential
from pipeline.component.nn.backend.torch.interactive import InteractiveLayer
from pipeline.interface import Input
from pipeline.interface import Output
from pipeline.utils.tools import extract_explicit_parameter
from pipeline.component.nn.interface import DatasetParam
class HeteroNN(FateComponent):
    """Pipeline component that collects the definition of a Hetero NN job.

    Bottom, interactive and top networks are accumulated in ``Sequential``
    holders (``add_bottom_model`` / ``set_interactive_layer`` /
    ``add_top_model``).  ``compile`` converts the optimizer, the loss and the
    non-empty holders into plain configuration stored on the instance and in
    ``self._component_param``.
    """

    @extract_explicit_parameter
    def __init__(self, task_type="classification", epochs=None, batch_size=-1, early_stop="diff",
                 tol=1e-5, encrypt_param=None, predict_param=None, cv_param=None, interactive_layer_lr=0.1,
                 validation_freqs=None, early_stopping_rounds=None, use_first_metric_only=None,
                 floating_point_precision=23, selector_param=None, seed=100,
                 dataset: DatasetParam = DatasetParam(dataset_name='table'), **kwargs
                 ):
        """
        Parameters used for Hetero Neural Network.

        The named arguments are not read directly in this body; the values
        that are applied come from ``kwargs["explict_parameters"]``
        (presumably filled in by the ``extract_explicit_parameter``
        decorator -- confirm against its implementation).

        Parameters
        ----------
        task_type: str, task type of hetero nn model, one of 'classification', 'regression'.
        interactive_layer_lr: float, the learning rate of interactive layer.
        epochs: int, the maximum iteration for aggregation in training.
        batch_size : int, batch size when updating model.
            -1 means use all data in a batch, i.e. not to use mini-batch strategy.
            Defaults to -1.
        early_stop : str, accept 'diff' only in this version, default: 'diff'
            Method used to judge convergence.
            a) diff: use the difference of loss between two iterations to judge convergence.
        tol: float, tolerance value for early stop.
        floating_point_precision: None or integer, if not None, use floating_point_precision-bit to speed up
            calculation, e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier
            operation, divide the result by 2**floating_point_precision in the end.
        encrypt_param: dict, see federatedml/param/encrypt_param
        dataset: DatasetParam, interface defining the dataset param. When present on the
            instance it is type-checked, ``check()``-ed and stored as ``to_dict()`` output.
        early_stopping_rounds: integer larger than 0
            will stop training if one metric of one validation data
            doesn't improve in the last early_stopping_rounds rounds;
            need to set validation_freqs; early stopping is checked at every validation epoch.
        validation_freqs: None or positive integer or container object in python
            Do validation in training process or not.
            if equals None, will not do validation in train process;
            if equals positive integer, will validate data every validation_freqs epochs;
            if container object in python, will validate data if the epoch belongs to this container.
            e.g. validation_freqs = [10, 15], will validate data when epoch equals 10 and 15.
            Default: None
        predict_param, cv_param, use_first_metric_only, selector_param, seed:
            not interpreted in this class; NOTE(review): semantics are defined
            elsewhere in the project -- confirm before documenting further.

        Notes
        -----
        Earlier documentation also listed ``callback_param`` (dict, CallbackParam,
        see federatedml/param/callback_param). It is not a named argument of this
        constructor; presumably it can still arrive through ``**kwargs`` -- confirm.

        Raises
        ------
        AssertionError
            If a ``dataset`` attribute is present but is not a ``DatasetParam``.
        """
        explicit_parameters = kwargs["explict_parameters"]
        # Register the network/optimizer/loss slots alongside the explicit
        # parameters so they are handed to the base constructor as well.
        for define_key in ("optimizer", "bottom_nn_define", "top_nn_define",
                           "interactive_layer_define", "loss"):
            explicit_parameters[define_key] = None
        FateComponent.__init__(self, **explicit_parameters)

        # "name" is consumed by the base class; everything else becomes an
        # attribute on this instance.
        explicit_parameters.pop("name", None)
        for param_key, param_value in explicit_parameters.items():
            setattr(self, param_key, param_value)

        self.input = Input(self.name, data_type="multi")
        self.output = Output(self.name, data_type='single')
        self._module_name = "HeteroNN"
        self.optimizer = None
        self.bottom_nn_define = None
        self.top_nn_define = None
        self.interactive_layer_define = None

        # model holder
        self._bottom_nn_model = Sequential()
        self._interactive_layer = Sequential()
        self._top_nn_model = Sequential()

        # role
        self._role = 'common'  # common/guest/host

        if hasattr(self, 'dataset'):
            self.dataset = self._checked_dataset_dict(self.dataset)

    @staticmethod
    def _checked_dataset_dict(dataset_param):
        """Validate ``dataset_param`` and return its ``to_dict()`` form.

        Raises AssertionError explicitly (instead of using ``assert``) so the
        check is still performed when Python runs with ``-O``.
        """
        if not isinstance(dataset_param, DatasetParam):
            raise AssertionError('dataset must be a DatasetParam class')
        dataset_param.check()
        return dataset_param.to_dict()

    def _ensure_holder(self, holder_attr):
        """Return the ``Sequential`` holder named ``holder_attr``, creating it if absent."""
        # Holders can be missing: __getstate__ drops them from the state.
        if not hasattr(self, holder_attr):
            setattr(self, holder_attr, Sequential())
        return getattr(self, holder_attr)

    def _export_network_config(self, holder_attr, define_attr):
        """Store a non-empty holder's network config as ``define_attr``.

        The config is written both to the instance attribute and to
        ``self._component_param``; missing or empty holders are left alone.
        """
        if not hasattr(self, holder_attr):
            return
        holder = getattr(self, holder_attr)
        if holder.is_empty():
            return
        network_config = holder.get_network_config()
        setattr(self, define_attr, network_config)
        self._component_param[define_attr] = network_config

    def _gather_party_models(self, getter_name):
        """Call ``getter_name`` on every party instance and collect the results.

        Returns a dict keyed by party, containing only non-None results, or
        None when nothing was collected.
        """
        collected = {}
        for role_info in self._get_all_party_instance().values():
            for party, party_inst in role_info["party"].items():
                if party_inst is None:
                    continue
                model = getattr(party_inst, getter_name)()
                if model is not None:
                    collected[party] = model
        return collected or None

    def set_role(self, role):
        """Record the role of this component instance (common/guest/host)."""
        self._role = role

    def get_party_instance(self, role="guest", party_id=None) -> 'Component':
        """Return the base class's party instance after tagging it with ``role``."""
        inst = super().get_party_instance(role, party_id)
        inst.set_role(role)
        return inst

    def add_dataset(self, dataset_param: DatasetParam):
        """Validate ``dataset_param`` and register it as the "dataset" parameter.

        Raises
        ------
        AssertionError
            If ``dataset_param`` is not a ``DatasetParam``.
        """
        self.dataset = self._checked_dataset_dict(dataset_param)
        self._component_parameter_keywords.add("dataset")
        self._component_param["dataset"] = self.dataset

    def add_bottom_model(self, model):
        """Append ``model`` to the bottom network holder."""
        self._ensure_holder("_bottom_nn_model").add(model)

    def set_interactive_layer(self, layer):
        """Append an ``InteractiveLayer`` to the interactive layer holder.

        Raises
        ------
        RuntimeError
            If the role of this instance is neither 'common' nor 'guest'.
        AssertionError
            If ``layer`` is not an ``InteractiveLayer``.
        """
        if self._role not in ('common', 'guest'):
            raise RuntimeError(
                'You can only set interactive layer in "common" or "guest" hetero nn component')
        if not isinstance(layer, InteractiveLayer):
            # Explicit raise keeps the check active under ``python -O``.
            raise AssertionError(
                'You need to add an interactive layer instance, \n'
                'you can access InteractiveLayer by:\n'
                't.nn.InteractiveLayer after fate_torch_hook(t)\n'
                'or from pipeline.component.nn.backend.torch.interactive '
                'import InteractiveLayer')
        self._ensure_holder("_interactive_layer").add(layer)

    def add_top_model(self, model):
        """Append ``model`` to the top network holder.

        Raises
        ------
        RuntimeError
            If the role of this instance is 'host'.
        """
        if self._role == 'host':
            raise RuntimeError('top model is not allow to set on host model')
        self._ensure_holder("_top_nn_model").add(model)

    def _set_optimizer(self, opt):
        """Store ``opt.to_dict()`` as the optimizer configuration.

        Raises AssertionError if ``opt`` has no ``to_dict`` attribute.
        """
        if not hasattr(opt, 'to_dict'):
            raise AssertionError(
                'opt does not have function to_dict(), remember to call fate_torch_hook(t)')
        self.optimizer = opt.to_dict()

    def _set_loss(self, loss):
        """Store ``loss.to_dict()`` as the loss configuration.

        Raises AssertionError if ``loss`` has no ``to_dict`` attribute.
        """
        if not hasattr(loss, 'to_dict'):
            raise AssertionError(
                'loss does not have function to_dict(), remember to call fate_torch_hook(t)')
        self.loss = loss.to_dict()

    def compile(self, optimizer, loss):
        """Record optimizer and loss, then export all network configurations.

        Configurations are exported for this instance first, then for every
        party instance, and finally the interactive layer of this instance.
        """
        self._set_optimizer(optimizer)
        self._set_loss(loss)
        self._compile_common_network_config()
        self._compile_role_network_config()
        self._compile_interactive_layer()

    def _compile_interactive_layer(self):
        """Export the interactive layer config if the holder exists and is non-empty."""
        self._export_network_config(
            "_interactive_layer", "interactive_layer_define")

    def _compile_common_network_config(self):
        """Export bottom and top network configs for holders that are non-empty."""
        self._export_network_config("_bottom_nn_model", "bottom_nn_define")
        self._export_network_config("_top_nn_model", "top_nn_define")

    def _compile_role_network_config(self):
        """Export network and interactive layer configs on every party instance."""
        all_party_instance = self._get_all_party_instance()
        for role in all_party_instance:
            for party_inst in all_party_instance[role]["party"].values():
                party_inst._compile_common_network_config()
                party_inst._compile_interactive_layer()

    def get_bottom_model(self):
        """Return the bottom model.

        If this instance holds a non-empty bottom holder, its model is
        returned. Otherwise the party instances are queried and a dict
        ``{party: model}`` is returned, or None when no party has one.
        """
        if hasattr(self, "_bottom_nn_model") and not self._bottom_nn_model.is_empty():
            return self._bottom_nn_model.get_model()
        return self._gather_party_models("get_bottom_model")

    def get_top_model(self):
        """Return the top model.

        If this instance holds a non-empty top holder, its model is returned.
        Otherwise the party instances are queried and a dict
        ``{party: model}`` is returned, or None when no party has one.
        """
        if hasattr(self, "_top_nn_model") and not self._top_nn_model.is_empty():
            return self._top_nn_model.get_model()
        return self._gather_party_models("get_top_model")

    def __getstate__(self):
        """Return a copy of the instance dict without the three model holders.

        NOTE(review): presumably the holders are excluded because they are not
        meant to be pickled/copied; this is why other methods guard holder
        access with ``hasattr`` -- confirm against the base class.
        """
        state = dict(self.__dict__)
        for holder_attr in ("_bottom_nn_model", "_interactive_layer", "_top_nn_model"):
            state.pop(holder_attr, None)
        return state