-
Notifications
You must be signed in to change notification settings - Fork 1
/
ss.m
622 lines (573 loc) · 22.8 KB
/
ss.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
function [x31, state] = ss(input_1, params, varargin)
%SS Function implementing an imported ONNX network.
%
% THIS FILE WAS AUTO-GENERATED BY importONNXFunction.
% ONNX Operator Set Version: 9
%
% Variable names in this function are taken from the original ONNX file.
%
% [X31] = ss(INPUT_1, PARAMS)
% - Evaluates the imported ONNX network SS with input(s)
% INPUT_1 and the imported network parameters in PARAMS. Returns
% network output(s) in X31.
%
% [X31, STATE] = ss(INPUT_1, PARAMS)
% - Additionally returns state variables in STATE. When training,
% use this form and set TRAINING to true.
%
% [__] = ss(INPUT_1, PARAMS, 'NAME1', VAL1, 'NAME2', VAL2, ...)
% - Specifies additional name-value pairs described below:
%
% 'Training'
% Boolean indicating whether the network is being evaluated for
% prediction or training. If TRAINING is true, state variables
% will be updated.
%
% 'InputDataPermutation'
% 'auto' - Automatically attempt to determine the permutation
% between the dimensions of the input data and the dimensions of
% the ONNX model input. For example, the permutation from HWCN
% (MATLAB standard) to NCHW (ONNX standard) uses the vector
% [4 3 1 2]. See the documentation for IMPORTONNXFUNCTION for
% more information about automatic permutation.
%
% 'none' - Input(s) are passed in the ONNX model format. See 'Inputs'.
%
% numeric vector - The permutation vector describing the
% transformation between input data dimensions and the expected
% ONNX input dimensions.
%
% cell array - If the network has multiple inputs, each cell
% contains 'auto', 'none', or a numeric vector.
%
% 'OutputDataPermutation'
% 'auto' - Automatically attempt to determine the permutation
% between the dimensions of the output and a conventional MATLAB
% dimension ordering. For example, the permutation from NC (ONNX
% standard) to CN (MATLAB standard) uses the vector [2 1]. See
% the documentation for IMPORTONNXFUNCTION for more information
% about automatic permutation.
%
% 'none' - Return output(s) as given by the ONNX model. See 'Outputs'.
%
% numeric vector - The permutation vector describing the
% transformation between the ONNX output dimensions and the
% desired output dimensions.
%
% cell array - If the network has multiple outputs, each cell
% contains 'auto', 'none' or a numeric vector.
%
% Inputs:
% -------
% INPUT_1
% - Input(s) to the ONNX network.
% The input size(s) expected by the ONNX file are:
% INPUT_1: [1, 3, 32, 32] Type: FLOAT
% By default, the function will try to permute the input(s)
% into this dimension ordering. If the default is incorrect,
% use the 'InputDataPermutation' argument to control the
% permutation.
%
%
% PARAMS - Network parameters returned by 'importONNXFunction'.
%
%
% Outputs:
% --------
% X31
% - Output(s) of the ONNX network.
% Without permutation, the size(s) of the outputs are:
% X31: [1, 10] Type: FLOAT
% By default, the function will try to permute the output(s)
% from this dimension ordering into a conventional MATLAB
% ordering. If the default is incorrect, use the
% 'OutputDataPermutation' argument to control the permutation.
%
% STATE - (Optional) State variables. When TRAINING is true, these will
% have been updated from the original values in PARAMS.State.
%
%
% See also importONNXFunction
% Preprocess the input data and arguments (validation, dlarray conversion,
% permutation into reverse-ONNX ordering, size checking):
[input_1, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(input_1, params, varargin{:});
% Put all variables into a single struct to implement dynamic scoping:
[Vars, NumDims] = packageVariables(params, {'input_1'}, {input_1}, [4]);
% Call the top-level graph function:
[x31, x31NumDims, state] = torch_jit_exportGraph1000(input_1, NumDims.input_1, Vars, NumDims, Training, params.State);
% Postprocess the output data (restore numeric type if appropriate and
% apply the requested output permutation):
[x31] = postprocessOutput(x31, outputDataPerms, anyDlarrayInputs, Training, varargin{:});
end
function [x31, x31NumDims1017, state] = torch_jit_exportGraph1000(input_1, input_1NumDims1016, Vars, NumDims, Training, state)
% Function implementing the graph 'torch_jit_exportGraph1000'
% Network structure as visible below: three Conv+ReLU stages, a flatten
% implemented with Shape/Gather/Unsqueeze/Concat/Reshape, then two Gemm
% (fully connected) layers with a ReLU between them.
% NOTE(review): Vars.x22, Vars.x24 and the Conv*/Gemm*/Unsqueeze* attribute
% entries are constants loaded from the ONNX file via PARAMS — presumably
% x22 selects the batch dimension and x24 is -1 for "infer"; not visible here.
% Update Vars and NumDims from the graph's formal input parameters. Note that state variables are already in Vars.
Vars.input_1 = input_1;
NumDims.input_1 = input_1NumDims1016;
% Execute the operators:
% Conv (blocks_layers_1):
[weights, bias, stride, dilationFactor, padding, dataFormat, NumDims.x15] = prepareConvArgs(Vars.blocks_layers_1_conv_weight, Vars.blocks_layers_1_conv_bias, Vars.ConvStride1001, Vars.ConvDilationFactor1002, Vars.ConvPadding1003, 1, NumDims.input_1, NumDims.blocks_layers_1_conv_weight);
Vars.x15 = dlconv(Vars.input_1, weights, bias, 'Stride', stride, 'DilationFactor', dilationFactor, 'Padding', padding, 'DataFormat', dataFormat);
% Relu:
Vars.x16 = relu(Vars.x15);
NumDims.x16 = NumDims.x15;
% Conv (blocks_layers_3):
[weights, bias, stride, dilationFactor, padding, dataFormat, NumDims.x17] = prepareConvArgs(Vars.blocks_layers_3_conv_weight, Vars.blocks_layers_3_conv_bias, Vars.ConvStride1004, Vars.ConvDilationFactor1005, Vars.ConvPadding1006, 1, NumDims.x16, NumDims.blocks_layers_3_conv_weight);
Vars.x17 = dlconv(Vars.x16, weights, bias, 'Stride', stride, 'DilationFactor', dilationFactor, 'Padding', padding, 'DataFormat', dataFormat);
% Relu:
Vars.x18 = relu(Vars.x17);
NumDims.x18 = NumDims.x17;
% Conv (blocks_layers_5):
[weights, bias, stride, dilationFactor, padding, dataFormat, NumDims.x19] = prepareConvArgs(Vars.blocks_layers_5_conv_weight, Vars.blocks_layers_5_conv_bias, Vars.ConvStride1007, Vars.ConvDilationFactor1008, Vars.ConvPadding1009, 1, NumDims.x18, NumDims.blocks_layers_5_conv_weight);
Vars.x19 = dlconv(Vars.x18, weights, bias, 'Stride', stride, 'DilationFactor', dilationFactor, 'Padding', padding, 'DataFormat', dataFormat);
% Relu:
Vars.x20 = relu(Vars.x19);
NumDims.x20 = NumDims.x19;
% Shape: reverse-ONNX shape of x20 as a column vector.
[Vars.x21, NumDims.x21] = onnxShape(Vars.x20, NumDims.x20);
% Gather: pick one element of the shape (axis 0, index in Vars.x22).
[Vars.x23, NumDims.x23] = onnxGather(Vars.x21, Vars.x22, 0, NumDims.x21, NumDims.x22);
% Unsqueeze: make the gathered scalar a 1-D tensor.
[shape, NumDims.x25] = prepareUnsqueezeArgs(Vars.x23, Vars.UnsqueezeAxes1010, NumDims.x23);
Vars.x25 = reshape(Vars.x23, shape);
% Unsqueeze: same for the constant Vars.x24.
[shape, NumDims.x26] = prepareUnsqueezeArgs(Vars.x24, Vars.UnsqueezeAxes1011, NumDims.x24);
Vars.x26 = reshape(Vars.x24, shape);
% Concat: build the 2-element target shape for the flatten.
[Vars.x27, NumDims.x27] = onnxConcat(0, {Vars.x25, Vars.x26}, [NumDims.x25, NumDims.x26]);
% Reshape: flatten x20 to the shape held in x27.
[shape, NumDims.x28] = prepareReshapeArgs(Vars.x20, Vars.x27, NumDims.x20, 0);
Vars.x28 = reshape(Vars.x20, shape{:});
% Gemm (blocks_layers_8, fully connected):
[A, B, C, alpha, beta, NumDims.x29] = prepareGemmArgs(Vars.x28, Vars.blocks_layers_8_linear_weight, Vars.blocks_layers_8_linear_bias, Vars.Gemmalpha1012, Vars.Gemmbeta1013, 0, 1, NumDims.blocks_layers_8_linear_bias);
Vars.x29 = alpha*B*A + beta*C;
% Relu:
Vars.x30 = relu(Vars.x29);
NumDims.x30 = NumDims.x29;
% Gemm (blocks_layers_10, final fully connected):
[A, B, C, alpha, beta, NumDims.x31] = prepareGemmArgs(Vars.x30, Vars.blocks_layers_10_linear_weight, Vars.blocks_layers_10_linear_bias, Vars.Gemmalpha1014, Vars.Gemmbeta1015, 0, 1, NumDims.blocks_layers_10_linear_bias);
Vars.x31 = alpha*B*A + beta*C;
% Set graph output arguments from Vars and NumDims:
x31 = Vars.x31;
x31NumDims1017 = NumDims.x31;
% Set output state from Vars:
state = updateStruct(state, Vars);
end
function [inputDataPerms, outputDataPerms, Training] = parseInputs(input_1, numDataOutputs, params, varargin)
% Validate the arguments passed to ss and normalize the permutation
% name-value pairs into cell arrays (one cell per input/output).
p = inputParser;
addRequired(p, 'input_1', @(x)isnumeric(x) || isstring(x));
addRequired(p, 'params', @(x)isa(x, 'ONNXParameters'));
addParameter(p, 'InputDataPermutation', 'auto');
addParameter(p, 'OutputDataPermutation', 'auto');
addParameter(p, 'Training', false);
parse(p, input_1, params, varargin{:});
Training = p.Results.Training;
inputDataPerms = normalizePermSpec(p.Results.InputDataPermutation, 1);
outputDataPerms = normalizePermSpec(p.Results.OutputDataPermutation, numDataOutputs);
end
function spec = normalizePermSpec(spec, n)
% Wrap a single numeric permutation in a cell, or replicate a single
% 'auto'/'none' string so there is one entry per input/output.
if isnumeric(spec)
    spec = {spec};
elseif (isstring(spec) && isscalar(spec)) || ischar(spec)
    spec = repmat({spec}, 1, n);
end
end
function [input_1, Training, outputDataPerms, anyDlarrayInputs] = preprocessInput(input_1, params, varargin)
% Validate arguments, convert the input to an unlabelled dlarray, permute
% it into reverse-ONNX ordering, and check its size.
[inPerms, outputDataPerms, Training] = parseInputs(input_1, 1, params, varargin{:});
% Remember whether the caller passed dlarray data so postprocessing can
% return the same type.
anyDlarrayInputs = isa(input_1, 'dlarray');
% Make the input variable into an unlabelled dlarray:
input_1 = makeUnlabeledDlarray(input_1);
% Permute the input if requested (4 = number of ONNX dims of input_1):
input_1 = permuteInputVar(input_1, inPerms{1}, 4);
% Check the input size against the size expected by the ONNX file:
checkInputSize(size(input_1), {1 3 32 32}, "input_1");
end
function [x31] = postprocessOutput(x31, outputDataPerms, anyDlarrayInputs, Training, varargin)
% Strip the dlarray wrapper when the caller supplied plain numeric data and
% is not training, then apply the requested output permutation.
returnPlainNumeric = ~anyDlarrayInputs && ~Training;
if returnPlainNumeric && isdlarray(x31)
    x31 = extractdata(x31);
end
% Permute the output if requested (2 = number of ONNX dims of x31):
x31 = permuteOutputVar(x31, outputDataPerms{1}, 2);
end
%% dlarray functions implementing ONNX operators:
function [Y, numDimsY] = onnxConcat(ONNXAxis, XCell, numDimsXArray)
% Concatenation that treats all empties the same. Necessary because
% dlarray.cat does not allow, for example, cat(1, 1x1, 1x0) because the
% second dimension sizes do not match.
numDimsY = numDimsXArray(1);
% Drop empty operands before concatenating.
nonEmpty = XCell(~cellfun(@isempty, XCell));
if isempty(nonEmpty)
    Y = dlarray([]);
    return
end
% Convert the ONNX axis (origin-0, possibly negative) to a DLT dimension
% (origin-1, counted from the end because dimensions are reversed).
if ONNXAxis < 0
    ONNXAxis = ONNXAxis + numDimsY;
end
Y = cat(numDimsY - ONNXAxis, nonEmpty{:});
end
function [Y, numDimsY] = onnxGather(X, ONNXIdx, ONNXAxis, numDimsX, numDimsIdx)
% Function implementing the ONNX Gather operator
% In ONNX, 'Gather' first indexes into dimension ONNXAxis of data, using
% the contents of ONNXIdx as the indices. Then, it reshapes the ONNXAxis
% into the shape of ONNXIdx.
% Example 1:
% Suppose data has shape [2 3 4 5], ONNXIdx has shape [6 7], and axis=1.
% The result has shape [2 6 7 4 5].
% Example 2:
% Suppose data has shape [2 3 4 5], ONNXIdx has shape [6], and axis=1.
% The result has shape [2 6 4 5].
% Example 3:
% Suppose data has shape [2 3 4 5], ONNXIdx has shape [] (a scalar), and axis=1.
% The result has shape [2 4 5].
%
% Since we're using reverse indexing relative to ONNX, in this function
% data and ONNXIdx both have reversed dimension ordering.
numDimsY = numDimsIdx + (numDimsX - 1);
% Empty data gathers to empty; nothing to index.
if isempty(X)
Y = X;
return;
end
% (1) First, do the subsref part of Gather
if ONNXAxis<0
ONNXAxis = ONNXAxis + numDimsX; % Axis can be negative. Convert it to its positive equivalent.
end
dltAxis = numDimsX - ONNXAxis; % Convert axis to DLT. ONNXAxis is origin 0 and we index from the end
ONNXIdx(ONNXIdx<0) = ONNXIdx(ONNXIdx<0) + size(X, dltAxis); % ONNXIdx can have negative components. Make them positive.
dltIdx = extractdata(ONNXIdx) + 1; % ONNXIdx is origin-0 in ONNX, so add 1 to get dltIdx
% Use subsref to index into data
Indices.subs = repmat({':'}, 1, numDimsX);
Indices.subs{dltAxis} = dltIdx(:); % Index as a column to ensure the output is 1-D in the indexed dimension (for now).
Indices.type = '()';
Y = subsref(X, Indices);
% (2) Now do the reshaping part of Gather
shape = size(Y, 1:numDimsX);
if numDimsIdx == 0
% Delete the indexed dimension
shape(dltAxis) = [];
elseif numDimsIdx > 1
% Reshape the indexed dimension into the shape of ONNXIdx
shape = [shape(1:dltAxis-1) size(ONNXIdx, 1:numDimsIdx) shape(dltAxis+1:end)];
end
% (When numDimsIdx == 1 the indexed dimension already has the right extent,
% so the shape is left unchanged.)
% Extend the shape to 2D so it's valid MATLAB
if numel(shape) < 2
shape = [shape ones(1,2-numel(shape))];
end
Y = reshape(Y, shape);
end
function [Y, numDimsY] = onnxShape(X, numDimsX)
% Implements the ONNX Shape operator.
% Returns the reverse-ONNX shape of X as a 1-D column vector (a dlarray).
numDimsY = 1;
if numDimsX <= 1
    % 0-D and 1-D tensors produce a single shape element: 0 when the data
    % is empty, otherwise 1 (scalar) or the first MATLAB dimension.
    if isempty(X)
        Y = dlarray(0);
    elseif numDimsX == 0
        Y = dlarray(1);
    else
        Y = dlarray(size(X, 1));
    end
else
    % Flip MATLAB dimension order to get the reverse-ONNX shape.
    Y = dlarray(fliplr(size(X, 1:numDimsX))');
end
end
function [weights, bias, stride, dilationFactor, padding, dataFormat, numDimsY] = prepareConvArgs(...
weights, bias, stride, dilationFactor, padding, numWtGroups, numDimsX, numDimsW)
% Prepares arguments for implementing the ONNX Conv operator via dlconv.
%
% Inputs:
%   weights        - convolution kernel in reverse-ONNX layout (see below)
%   bias           - bias vector (may be empty)
%   stride, dilationFactor, padding - ONNX attributes (may be empty => defaults)
%   numWtGroups    - number of ONNX convolution groups G
%   numDimsX       - number of ONNX dimensions of the input tensor
%   numDimsW       - number of ONNX dimensions of the weight tensor
% Outputs: arguments in the form dlconv expects, plus numDimsY = numDimsX.
%
% Weights: The ONNX weight dim is Fcxyz..., where c=C/G, G is numGroups,
% and xyz... are spatial dimensions. DLT "weights" here is the flip of
% that, or ...zyxcF. dlconv requires ...zyxcfG, where f=F/G. So reshape to
% split the last dimension.
sizeW = size(weights, 1:numDimsW);
F = sizeW(end);
newWSize = [sizeW(1:numDimsW-1), F/numWtGroups, numWtGroups];
weights = reshape(weights, newWSize);
% bias: dlconv requires a labelled dlarray; 'CU' marks the channel dim.
if isempty(bias)
bias = 0;
end
bias = dlarray(bias(:),'CU');
% Derive missing default attributes from weight tensor
numSpatialDims = numDimsW-2;
if isempty(padding)
padding = zeros(1, 2*numSpatialDims);
end
if isempty(stride)
stride = ones(1,numSpatialDims);
end
if isempty(dilationFactor)
dilationFactor = ones(1,numSpatialDims);
end
% Make the attributes non-dlarrays:
if isa(stride, 'dlarray')
stride = extractdata(stride);
end
if isa(dilationFactor, 'dlarray')
dilationFactor = extractdata(dilationFactor);
end
if isa(padding, 'dlarray')
padding = extractdata(padding);
end
% Make the attributes double row vectors, and flip their dimension ordering
% to reverse-onnx:
stride = fliplr(double(stride(:)'));
dilationFactor = fliplr(double(dilationFactor(:)'));
if isnumeric(padding) % padding can be "same"
% ONNX: [x1_begin, ..., xn_begin, x1_end, ...,xn_end]
% DLT: [xn_begin, ..., x1_begin;
% xn_end, ..., x1_end] (Note the lrflip and semicolon)
padding = fliplr(transpose(reshape(padding, [], 2)));
end
% Set dataformat and numdims: spatial dims, then channel, then batch.
dataFormat = [repmat('S', 1, numDimsX-2) 'CB'];
numDimsY = numDimsX;
end
function [A, B, C, alpha, beta, numDimsY] = prepareGemmArgs(A, B, C, alpha, beta, transA, transB, numDimsC)
% Prepares arguments for implementing the ONNX Gemm operator:
% Y = alpha*op(A)*op(B) + beta*C, which the caller computes as
% alpha*B*A + beta*C because A and B arrive in reverse-ONNX (transposed)
% layout already.
%
% Inputs:
%   A, B, C     - Gemm operands in reverse-ONNX layout
%   alpha, beta - scalar multipliers
%   transA/B    - nonzero when the ONNX attribute requests a transpose
%   numDimsC    - number of ONNX dimensions of C
% Outputs: possibly-transposed/reshaped operands, and numDimsY = 2.
if transA
    % Use .' (plain transpose): ONNX Gemm never conjugates, and ' would
    % conjugate complex data.
    A = A.';
end
if transB
    B = B.';
end
if numDimsC < 2
    C = C(:); % C can be broadcast to [N M]. Make C a col vector ([N 1])
end
numDimsY = 2;
% Y=B*A because we want (AB).'=B.'*A.', and B and A are already transposed.
end
function [DLTShape, numDimsY] = prepareReshapeArgs(X, ONNXShape, numDimsX, allowzero)
% Prepares arguments for implementing the ONNX Reshape operator.
% Returns DLTShape as a cell array suitable for reshape(X, DLTShape{:}).
ONNXShape = flip(extractdata(ONNXShape)); % First flip the shape to make it correspond to the dimensions of X.
% In ONNX, 0 means "unchanged" if allowzero is false, and -1 means "infer". In DLT, there is no
% "unchanged", and [] means "infer".
DLTShape = num2cell(ONNXShape); % Make a cell array so we can include [].
% Replace zeros with the actual size if allowzero is true
if any(ONNXShape==0) && allowzero==0
i0 = find(ONNXShape==0);
% Right-align the (flipped) shape vector against the trailing dimensions
% of X so each 0 picks up the size of the dimension it leaves unchanged.
DLTShape(i0) = num2cell(size(X, numDimsX - numel(ONNXShape) + i0)); % right-align the shape vector and dims
end
if any(ONNXShape == -1)
% Replace -1 with []. (ONNX permits at most one -1, so the logical index
% selects a single cell.)
i = ONNXShape == -1;
DLTShape{i} = [];
end
% reshape needs at least two size arguments; pad a 1-element shape with 1.
if numel(DLTShape)==1
DLTShape = [DLTShape 1];
end
numDimsY = numel(ONNXShape);
end
function [newShape, numDimsY] = prepareUnsqueezeArgs(X, ONNXAxes, numDimsX)
% Prepares arguments for implementing the ONNX Unsqueeze operator.
% Returns the new size vector so the caller can do reshape(X, newShape),
% inserting singleton dimensions at the requested ONNX axes.
numDimsY = numDimsX + numel(ONNXAxes);
ONNXAxes = extractdata(ONNXAxes);
% Axes may be negative; convert to their positive equivalents.
ONNXAxes(ONNXAxes<0) = ONNXAxes(ONNXAxes<0) + numDimsY;
ONNXAxes = sort(ONNXAxes); % increasing order
if numDimsY == 1
% Unsqueezing a scalar to 1-D: the MATLAB size is already [1 1].
newShape = size(X);
else
% Convert origin-0 ONNX axes (counted from the front) to DLT positions
% (counted from the end), then place the existing sizes in the
% remaining positions and leave 1s at the inserted axes.
DLTAxes = flip(numDimsY - ONNXAxes); % increasing order
newShape = ones(1, numDimsY);
posToSet = setdiff(1:numDimsY, DLTAxes, 'stable');
newShape(posToSet) = size(X, 1:numel(posToSet));
end
end
%% Utility functions:
function s = appendStructs(varargin)
% s = appendStructs(s1, s2, ...). Merge all fields of s2, s3, ... into s1.
% A later struct overwrites earlier fields of the same name. With no
% arguments, returns an empty struct.
if nargin == 0
    s = struct;
    return
end
s = varargin{1};
for k = 2:nargin
    src = varargin{k};
    names = fieldnames(src);
    for j = 1:numel(names)
        s.(names{j}) = src.(names{j});
    end
end
end
function checkInputSize(inputShape, expectedShape, inputName)
% Validate that the MATLAB size of an input matches the shape declared in
% the ONNX file, erroring with a resize/permute message when it does not.
%   inputShape    - MATLAB size vector of the (already permuted) input
%   expectedShape - cell array of expected ONNX dims; char/string entries
%                   denote variable-size dimensions
%   inputName     - name used in error messages
if numel(expectedShape)==0
% The input is a scalar
if ~isequal(inputShape, [1 1])
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, "[1,1]", inputSizeStr));
end
elseif numel(expectedShape)==1
% The input is a vector
if ~shapeIsColumnVector(inputShape) || ~iSizesMatch({inputShape(1)}, expectedShape)
expectedShape{2} = 1;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
else
% The input has 2 dimensions or more
% The input dimensions have been reversed; flip them back to compare to the
% expected ONNX shape.
inputShape = fliplr(inputShape);
% If the expected shape has fewer dims than the input shape, error.
if numel(expectedShape) < numel(inputShape)
expectedSizeStr = strjoin(["[", strjoin(string(expectedShape), ","), "]"], "");
error(message('nnet_cnn_onnx:onnx:InputHasGreaterNDims', inputName, expectedSizeStr));
end
% Prepad the input shape with leading ones up to the number of elements in
% expectedShape
inputShape = num2cell([ones(1, numel(expectedShape) - length(inputShape)) inputShape]);
% Find the number of variable size dimensions in the expected shape
numVariableInputs = sum(cellfun(@(x) isa(x, 'char') || isa(x, 'string'), expectedShape));
% Find the number of input dimensions that are not in the expected shape
% and cannot be represented by a variable dimension
nonMatchingInputDims = setdiff(string(inputShape), string(expectedShape));
numNonMatchingInputDims = numel(nonMatchingInputDims) - numVariableInputs;
expectedSizeStr = makeSizeString(expectedShape);
inputSizeStr = makeSizeString(inputShape);
if numNonMatchingInputDims == 0 && ~iSizesMatch(inputShape, expectedShape)
% The actual and expected input dimensions match, but in
% a different order. The input needs to be permuted.
error(message('nnet_cnn_onnx:onnx:InputNeedsPermute',inputName, expectedSizeStr, inputSizeStr));
elseif numNonMatchingInputDims > 0
% The actual and expected input sizes do not match.
error(message('nnet_cnn_onnx:onnx:InputNeedsResize',inputName, expectedSizeStr, inputSizeStr));
end
end
end
function doesMatch = iSizesMatch(inputShape, expectedShape)
% Check whether the input and expected shapes match, element by element.
% An element matches when the sizes are equal or the expected element is a
% variable dimension (represented by a character vector or string).
isWildcard = @(e) ischar(e) || isstring(e);
for k = 1:numel(inputShape)
    if ~isWildcard(expectedShape{k}) && ~isequal(inputShape{k}, expectedShape{k})
        doesMatch = false;
        return
    end
end
doesMatch = true;
end
function sizeStr = makeSizeString(shape)
% Format a shape (numeric vector or cell array) as a string like "[1,3,32,32]".
sizeStr = "[" + strjoin(string(shape), ",") + "]";
end
function isVec = shapeIsColumnVector(shape)
% True when the given size vector describes an Nx1 column vector.
isVec = numel(shape) == 2 && shape(2) == 1;
end
function X = makeUnlabeledDlarray(X)
% Convert numeric X into an unlabelled dlarray, or strip the dimension
% labels from an existing dlarray.
if isa(X, 'dlarray')
    X = stripdims(X);
    return
end
if isnumeric(X)
    if isinteger(X)
        % Promote integers to double so they can combine with anything
        % without reducing precision.
        X = double(X);
    end
    X = dlarray(X);
end
end
function [Vars, NumDims] = packageVariables(params, inputNames, inputValues, inputNumDims)
% Collect learnables, nonlearnables, state, and the graph inputs into a
% single struct pair: values in Vars, ONNX dimension counts in NumDims.
% inputNames and inputValues are cell arrays; inputNumDims is numeric.
Vars = appendStructs(params.Learnables, params.Nonlearnables, params.State);
NumDims = params.NumDimensions;
% Add the graph inputs:
for k = 1:numel(inputNames)
    name = inputNames{k};
    Vars.(name) = inputValues{k};
    NumDims.(name) = inputNumDims(k);
end
end
function X = permuteInputVar(X, userDataPerm, onnxNDims)
% Permute user input data into reverse-ONNX dimension ordering, according
% to the 'InputDataPermutation' option.
if onnxNDims == 0
    return
end
if onnxNDims == 1 && isvector(X)
    X = X(:);   % 1-D ONNX input: just make it a column
    return
end
if isnumeric(userDataPerm)
    % Explicit user permutation: validate its length, then flip it to get
    % reverse-ONNX ordering.
    if numel(userDataPerm) ~= onnxNDims
        error(message('nnet_cnn_onnx:onnx:InputPermutationSize', numel(userDataPerm), onnxNDims));
    end
    perm = fliplr(userDataPerm);
elseif isequal(userDataPerm, 'as-is')
    perm = 1:ndims(X);              % leave the data untouched
elseif isequal(userDataPerm, 'auto') && onnxNDims == 4
    perm = [2 1 3 4];               % MATLAB HWCN -> reverse-ONNX WHCN
else
    % 'none', or 'auto' with no known default: the data is already in ONNX
    % ordering, so reverse it to get reverse-ONNX.
    perm = max(2, onnxNDims):-1:1;
end
X = permute(X, perm);
end
function Y = permuteOutputVar(Y, userDataPerm, onnxNDims)
% Permute a network output from reverse-ONNX ordering into the ordering
% requested via the 'OutputDataPermutation' option.
if onnxNDims == 0
    perm = [];
elseif onnxNDims == 1
    if isnumeric(userDataPerm)
        % Y is a column vector, which already matches ONNX; apply the
        % user's permutation directly.
        perm = userDataPerm;
    elseif isequal(userDataPerm, 'auto')
        % Treat the 1-D ONNX vector as a 2-D column and transpose it.
        perm = [2 1];
    else
        % 'none': Y already matches ONNX; leave it alone.
        perm = [];
    end
else
    % ndims >= 2
    if isnumeric(userDataPerm)
        % Use the inverse of the user's permutation (not merely the flip
        % of the permutation vector).
        perm = onnxNDims + 1 - userDataPerm;
    elseif isequal(userDataPerm, 'auto')
        if onnxNDims == 2
            perm = [];              % reverse-ONNX CN already is MATLAB CN
        elseif onnxNDims == 4
            perm = [2 1 3 4];       % reverse-ONNX WHCN -> MATLAB HWCN
        else
            % Return plain ONNX ordering: reverse the reverse-ONNX dims.
            perm = onnxNDims:-1:1;
        end
    elseif isequal(userDataPerm, 'as-is')
        perm = 1:ndims(Y);          % do not permute
    else
        % 'none': plain ONNX ordering.
        perm = onnxNDims:-1:1;
    end
end
if ~isempty(perm)
    Y = permute(Y, perm);
end
end
function s = updateStruct(s, t)
% Copy into s the value of every field s already has, taking the values
% from t. Extra fields of t are ignored; s gains no new fields.
names = fieldnames(s);
for k = 1:numel(names)
    s.(names{k}) = t.(names{k});
end
end