-
Notifications
You must be signed in to change notification settings - Fork 1
/
matlab2nnv.m
153 lines (122 loc) · 5.71 KB
/
matlab2nnv.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
function net = matlab2nnv(Mnetwork)
% MATLAB2NNV Convert a MATLAB deep learning network into an NNV network.
%
%   net = MATLAB2NNV(Mnetwork)
%
%   Input:
%       Mnetwork - a SeriesNetwork, LayerGraph, DAGNetwork, or dlnetwork
%                  (MATLAB Deep Learning Toolbox network object)
%   Output:
%       net      - NNV NN object built from the parsed layers, carrying the
%                  original connections table and a name->index map
%
%   Throws an error for any input of the wrong class or any layer type
%   that NNV does not support.

    %% Check for correct inputs and process network
    ntype = class(Mnetwork); % input type
    if ~contains(ntype, ["SeriesNetwork", "LayerGraph", "DAGNetwork", "dlnetwork"])
        error('Wrong input type. Input must be a SeriesNetwork, LayerGraph, DAGNetwork, or dlnetwork');
    end

    %% Process input types
    % Input is a MATLAB type neural network (layergraph, seriesNetwork, dlnetwork or dagnetwork)
    if ntype == "SeriesNetwork"
        % SeriesNetwork has no Connections property; derive it via layerGraph
        conns = layerGraph(Mnetwork).Connections; % get the table of connections
    else
        conns = Mnetwork.Connections; % get the table of connections
    end

    Layers = Mnetwork.Layers; % get the list of layers

    %% Transform to NNV
    n = length(Layers);
    nnvLayers = cell(n,1);
    names = strings(n,1);

    % Parse network layer-by-layer
    for i = 1:n
        L = Layers(i);
        fprintf('\nParsing Layer %d... \n', i);

        % Layers with no effect on the reachability analysis
        if isa(L, 'nnet.cnn.layer.DropoutLayer') || isa(L, 'nnet.cnn.layer.SoftmaxLayer') || isa(L, 'nnet.cnn.layer.ClassificationOutputLayer') ...
                || isa(L, "nnet.onnx.layer.VerifyBatchSizeLayer") || isa(L, "nnet.cnn.layer.RegressionOutputLayer")
            Li = PlaceholderLayer.parse(L);

        % Image Input Layer
        elseif isa(L, 'nnet.cnn.layer.ImageInputLayer')
            Li = ImageInputLayer.parse(L);

        % Convolutional 2D layer
        elseif isa(L, 'nnet.cnn.layer.Convolution2DLayer')
            Li = Conv2DLayer.parse(L);

        % ReLU Layer (also referred to as poslin)
        elseif isa(L, 'nnet.cnn.layer.ReLULayer')
            Li = ReluLayer.parse(L);

        % Batch Normalization Layer
        elseif isa(L, 'nnet.cnn.layer.BatchNormalizationLayer')
            Li = BatchNormalizationLayer.parse(L);

        % Max Pooling 2D Layer
        elseif isa(L, 'nnet.cnn.layer.MaxPooling2DLayer')
            Li = MaxPooling2DLayer.parse(L);

        % Average Pooling 2D Layer
        elseif isa(L, 'nnet.cnn.layer.AveragePooling2DLayer')
            Li = AveragePooling2DLayer.parse(L);

        % Fully Connected Layer
        elseif isa(L, 'nnet.cnn.layer.FullyConnectedLayer')
            Li = FullyConnectedLayer.parse(L);

        % Pixel Classification Layer (used for Semantic Segmentation output)
        elseif isa(L, 'nnet.cnn.layer.PixelClassificationLayer')
            Li = PixelClassificationLayer.parse(L);

        % Flatten Layer
        elseif isa(L, 'nnet.keras.layer.FlattenCStyleLayer') || isa(L, 'nnet.cnn.layer.FlattenLayer') || isa(L, 'nnet.onnx.layer.FlattenLayer') ...
                || isa(L, 'nnet.onnx.layer.FlattenInto2dLayer')
            Li = FlattenLayer.parse(L);

        % Sigmoid Layer (also referred to as logsig)
        elseif isa(L, 'nnet.keras.layer.SigmoidLayer') || isa(L, 'nnet.onnx.layer.SigmoidLayer')
            Li = SigmoidLayer.parse(L);

        % ElementWise Affine Layer (often used as a bias layer after FC layers)
        elseif isa(L, 'nnet.onnx.layer.ElementwiseAffineLayer')
            Li = ElementwiseAffineLayer.parse(L);

        % Feature input layer
        elseif isa(L, 'nnet.cnn.layer.FeatureInputLayer')
            Li = FeatureInputLayer.parse(L);

        % Transposed Convolution 2D Layer
        elseif isa(L, 'nnet.cnn.layer.TransposedConvolution2DLayer')
            Li = TransposedConv2DLayer.parse(L);

        % Max Unpooling 2D Layer
        elseif isa(L, 'nnet.cnn.layer.MaxUnpooling2DLayer')
            Li = MaxUnpooling2DLayer.parse(L, conns);
            % BUGFIX: was 'connects' (undefined variable) — the connections
            % table in scope is 'conns'
            pairedMaxPoolingName = NN.getPairedMaxPoolingName(conns, Li.Name);
            Li.setPairedMaxPoolingName(pairedMaxPoolingName);

        % Depth Concatenation Layer (common in uNets)
        elseif isa(L, 'nnet.cnn.layer.DepthConcatenationLayer')
            Li = DepthConcatenationLayer.parse(L);

        % Concatenation Layer (concat dim part of layer properties)
        elseif isa(L, 'nnet.cnn.layer.ConcatenationLayer')
            Li = ConcatenationLayer.parse(L);

        % Reshape Layer (custom created after parsing ONNX layers)
        elseif contains(class(L), "ReshapeLayer")
            Li = ReshapeLayer.parse(L);

        % Custom flatten layers (avoid if possible)
        elseif contains(class(L), ["flatten"; "Flatten"])
            % Check previous layers to see if we can neglect this one in the analysis:
            % only safe if everything before it is an input layer with no
            % normalization, or already a placeholder
            for k = i-1:-1:1
                if contains(class(nnvLayers{k}), 'Input')
                    if ~strcmp(nnvLayers{k}.Normalization, 'none')
                        fprintf('Layer %d is a %s, which is not supported yet in nnv, please consider removing this layer for the analysis \n', i, class(L));
                        error('Unsupported Class of Layer');
                    end
                elseif ~isa(nnvLayers{k}, 'PlaceholderLayer')
                    fprintf('Layer %d is a %s, which is not supported yet in nnv, please consider removing this layer for the analysis \n', i, class(L));
                    error('Unsupported Class of Layer');
                end
            end
            % If we can neglect all previous layers, reinitialize layers and parse them again as placeholder layers
            nnvLayers = cell(n,1);
            % Parse all previous layers again
            for li = 1:i-1
                L = Layers(li);
                Li = PlaceholderLayer.parse(L);
                nnvLayers{li} = Li;
            end
            % Parse current flatten layer
            L = Layers(i);
            Li = PlaceholderLayer.parse(L);

        % All other layers are currently not supported in NNV
        else
            fprintf('Layer %d is a %s, which is not supported yet in nnv, please consider removing this layer for the analysis \n', i, class(L));
            error('Unsupported Class of Layer');
        end

        % Add layer name to list
        names(i) = string(L.Name);
        nnvLayers{i} = Li;
    end

    indxs = 1:n;
    % Assigning layer names to corresponding index
    name2number = containers.Map(names, indxs);

    % Create neural network
    net = NN(nnvLayers, conns);
    net.name2indx = name2number;

end