-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTFLayerPath.js
211 lines (180 loc) · 6.02 KB
/
TFLayerPath.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
//Ver 0.1.4test
//Helps create multi-path (graph-shaped) TensorFlow.js models
// Make sure the shared _guzuTF namespace object exists; reuse it if another script created it first.
if (!window._guzuTF) {
  window._guzuTF = {};
}
window._guzuTF.TFLayerPath=class TFLayerPath{
  /**
   * Records a named graph ("path") of layers that apply() later chains together
   * and model()/Model() turn into a tf.model.
   * @param {string} [name=""] label stored on the instance
   */
  constructor(name=""){
    // Each entry: [ name, layer, applyTo (name | [names] | -1 for an input), redirect index (-1 = none) ]
    this.layerPath=[];
    // name -> {id: position in layerPath, defaults: {trainable}} (trainable as it was when added)
    this.layerNames={};
    this.name=name;
    // When true, add() overwrites layer.name with the path name.
    this.rename=true;
  }

  /**
   * Returns a copy whose bookkeeping (layerPath entries and layerNames records)
   * is independent of this instance, so add()/redirect()/replace() on the copy
   * do not leak back. The layer objects themselves are shared, not copied.
   * @returns {TFLayerPath}
   */
  clone(){
    const copy=Object.assign(Object.create(Object.getPrototypeOf(this)),this);
    copy.layerPath=this.layerPath.map((entry)=>{
      const e=entry.slice();
      if(Array.isArray(e[2]))e[2]=e[2].slice();
      return e;
    });
    copy.layerNames={};
    Object.keys(this.layerNames).forEach((key)=>{
      const info=this.layerNames[key];
      copy.layerNames[key]={id:info.id,defaults:Object.assign({},info.defaults)};
    });
    return copy;
  }

  // True only for names registered via add(); own-key check so names such as
  // "toString" or "constructor" are not confused with Object.prototype members.
  _hasName(name){
    return Object.prototype.hasOwnProperty.call(this.layerNames,name);
  }

  /**
   * Add a layer to the path.
   * @param {string} index_ path name for the layer
   * @param {*} layer_ tf layer (or, for an input, the value returned to apply callers)
   * @param {string | string[] | number} [applytoIndex=-1] name(s) this layer is applied to; -1 marks an input
   * @returns {TFLayerPath} this, for chaining
   * @throws {Error} if index_ is already used
   */
  add(index_,layer_,applytoIndex=-1){
    if(this._hasName(index_))
      throw new Error(`[tf.util.path] Name conflict: ${index_}`);
    this.layerNames[index_]={id:this.layerPath.length,defaults:{trainable:layer_.trainable}};
    this.layerPath.push([index_,layer_,applytoIndex,-1]);
    this._lastIndex=index_;
    if(this.rename)layer_.name=index_;
    return this;
  }

  /**
   * Add a layer applied to the most recently added (or from()-selected) entry.
   * If nothing was added yet, _lastIndex is undefined, so add()'s default of -1
   * applies and the layer is recorded as an input.
   */
  to(index_,layer_){
    return this.add(index_,layer_,this._lastIndex);
  }

  /**
   * Select the entry that the next to() call attaches to.
   * Index 0 is accepted; other falsy values are ignored.
   */
  from(index_){
    if(index_||index_===0)this._lastIndex=index_;
    return this;
  }

  /**
   * Recursively call each layer's apply() on its inputs, producing the output
   * for path_, so the result can be used as a tf.model output.
   * @param {string|number} [path_] name or index; undefined means the last added layer
   * @param {boolean} [undefinedIsTop_=true] when false, an undefined path_ is reported as unknown
   * @returns the value of the layer's apply(); for an input entry, the stored value; null if unknown
   */
  apply(path_,undefinedIsTop_=true){
    if(path_===undefined){
      if(undefinedIsTop_)
        return this.apply(this.layerPath.length-1);
      console.warn("Path ["+this._prevPath+"] unknown");
      return null;
    }
    if(typeof path_==='string'){
      // Remember the requested name so an unknown top-level name is reported correctly.
      this._prevPath=path_;
      return this.apply(this.getIndex(path_),false);
    }
    const entry=this.layerPath[path_];
    if(entry[3]!==-1)//redirected: use the target entry instead
      return this.apply(entry[3],false);
    if(entry[2]===-1)//input entry: return it as-is
      return entry[1];
    const applyParent=(p)=>{
      this._prevPath=p;
      return this.apply(this.getIndex(p),false);
    };
    const applyto=Array.isArray(entry[2])?entry[2].map(applyParent):applyParent(entry[2]);
    return entry[1].apply(applyto,false);
  }

  // Returns the stored layer (useful for inputs), or undefined if the name/index is unknown.
  get(name){
    const entry=this.getPath(name);
    return entry===undefined?undefined:entry[1];
  }

  // Returns the layerPath entry [name, layer, applyTo, redirect], or undefined.
  getPath(name){
    const idx=this.getIndex(name);
    return idx===undefined?undefined:this.layerPath[idx];
  }

  /**
   * Resolve a name or index to a numeric position in layerPath.
   * Numbers and numeric-looking strings are treated as positions (converted to
   * numbers so callers never loop on a string index); other strings are looked
   * up among the names registered by add().
   * @returns {number|undefined}
   */
  getIndex(name){
    if(typeof name==='number')return Number.isNaN(name)?undefined:name;
    if(typeof name!=='string')return undefined;
    if(name.trim()!==''&&!isNaN(name))return Number(name);
    return this._hasName(name)?this.layerNames[name].id:undefined;
  }

  // Returns the {id, defaults:{trainable}} record for a name or index, or undefined.
  getInfo(name){
    const entry=this.getPath(name);
    return entry===undefined?undefined:this.layerNames[entry[0]];
  }

  /**
   * Make name_ resolve to another entry when applied (or pass -1 to clear).
   * An unresolvable target clears the redirect and logs a warning.
   */
  redirect(name_,to_=-1){
    let target=-1;
    if(to_!==-1){
      target=this.getIndex(to_);
      if(target===undefined){
        target=-1;
        console.warn("[tf.util.path] Unable to redirect '"+name_+"' to "+to_);
      }
    }
    this.getPath(name_)[3]=target;
    return this;
  }

  /**
   * @param {number | string} index_ index or name of the entry
   * @param {any} layer_ a string or array replaces the applyTo slot; anything else replaces the layer
   */
  replace(index_,layer_){
    const slot=(Array.isArray(layer_)||typeof layer_==='string')?2:1;
    this.getPath(index_)[slot]=layer_;
    return this;
  }

  /**
   * Walk from rootIndex_ back toward the inputs, setting layer.trainable=targetState_,
   * stopping after endIndex_ (if given) has been processed.
   * @param {string|number} [rootIndex_] start entry; defaults to the last added
   * @param {string|number} [endIndex_] last entry to touch
   * @param {boolean} targetState_ value written to trainable
   * @param {boolean} [exceptIfTrainableIs_] skip layers whose add()-time trainable equals this
   */
  //TODO:redirect in setTrainable needs testing
  setTrainable(rootIndex_,endIndex_,targetState_,exceptIfTrainableIs_){
    if(rootIndex_===undefined)rootIndex_=this.layerPath.length-1;
    const r=this.getPath(rootIndex_);
    if(r[3]!==-1){//redirect?
      this.setTrainable(r[3],endIndex_,targetState_,exceptIfTrainableIs_);
      return;
    }
    if(exceptIfTrainableIs_===undefined||exceptIfTrainableIs_!==this.getInfo(rootIndex_).defaults.trainable)
      r[1].trainable=targetState_;
    // Compare resolved positions so a name and its numeric index count as the same entry.
    const reachedEnd=endIndex_!==undefined&&this.getIndex(rootIndex_)===this.getIndex(endIndex_);
    if(reachedEnd||r[2]===-1)return;
    const parents=Array.isArray(r[2])?r[2]:[r[2]];
    for(const parent of parents)
      this.setTrainable(parent,endIndex_,targetState_,exceptIfTrainableIs_);
  }

  // Restore every layer between root and end to the trainable state it had when added.
  resetTrainable(rootIndex_,endIndex_){
    this.setTrainable(rootIndex_,endIndex_,true,false);
    this.setTrainable(rootIndex_,endIndex_,false,true);
  }

  /**
   * Build tf.model({inputs, outputs}) from names/indices (single values or arrays).
   * Defaults: inputs = [entry 0], outputs = [last added entry].
   * The caller's arrays are not modified.
   */
  model(inputNames,outputNames){
    const toList=(v)=>Array.isArray(v)?v:[v];
    const inputs=inputNames?toList(inputNames).map((n)=>this.get(n)):[this.get(0)];
    const outputs=outputNames?toList(outputNames).map((n)=>this.apply(n)):[this.apply()];
    return tf.model({inputs:inputs,outputs:outputs});
  }

  /**
   * Build and compile a model with simple defaults.
   * @param learningRate a number (adam learning rate, default 0.001) or an options
   *   object {optimizer, learningRate, metrics, loss}
   */
  Model(inputNames,outputNames,learningRate){
    let args;
    if(!isNaN(learningRate)||learningRate===undefined)
      args={learningRate:learningRate||0.001};
    else
      args=learningRate;
    const m=this.model(inputNames,outputNames);
    m.compile({
      optimizer:args.optimizer||tf.train.adam(args.learningRate||0.001),
      metrics:args.metrics||['accuracy'],
      loss:args.loss||'meanSquaredError',
    });
    return m;
  }
};
// Factory shortcut: tf.util.path(name) creates a new, empty TFLayerPath.
tf.util.path = (name) => new window._guzuTF.TFLayerPath(name);