-
Notifications
You must be signed in to change notification settings - Fork 9
/
flexBox.m
389 lines (319 loc) · 13.5 KB
/
flexBox.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
classdef flexBox < handle
    %FLEXBOX Primal-dual framework that combines primal variables with
    %functional terms ("duals").
    %
    %   Usage sketch (term classes are defined elsewhere):
    %       fb = flexBox;
    %       id = fb.addPrimalVar([rows, cols]);
    %       fb.addTerm(term, id);
    %       fb.runAlgorithm;
    %       u = fb.getPrimal(id);
    %
    %   runAlgorithm either hands the whole object to the compiled MEX
    %   function flexBoxCPP (when params.tryCPP is set and the MEX file is
    %   found) or repeats doIteration in MATLAB until the residual computed
    %   by calculateError is <= params.tol or params.maxIt iterations have
    %   been performed.
    properties
        params      %FlexBox params (tolerances, number of iterations, etc. see documentation)
        duals       %cell array of dual terms
        x           %current iteration values for primal part
        xOld        %old iteration values for primal part
        xBar        %overrelaxed value for x
        y           %current iteration values for dual part
        yOld        %old iteration values for dual part
        yTilde      %intermediate value for primal dual algorithm
        dims        %dimensions of the primal variables
        numberPvars %number of primal variables added so far (= last internal identifier)
        DcP         %which primal variable(s) belong to which dual part
        DcD         %which dual variable(s) belong to which dual part
        xError      %error estimates for primal parts
        yError      %error estimates for dual parts
        error       %combined error estimate
        firstRun    %internal flag specifying if the algorithm has already been run
    end
    methods
        function obj = flexBox
            %obj = flexBox
            %creates an empty FlexBox with default parameters:
            %   tol               residual threshold that stops runAlgorithm
            %   maxIt             maximum number of iterations of the MATLAB loop per run
            %   checkError        the residual is recomputed (and printed) every
            %                     checkError iterations
            %   theta             overrelaxation factor used to compute xBar
            %   verbose           > 0 prints which MEX file is used
            %   showPrimals       display the primals every showPrimals
            %                     iterations (0 disables the display)
            %   tryCPP            nonzero tries to use the compiled C++ module
            %   relativePathToMEX folder of the MEX file, relative to this file
            obj.params.tol = 1e-5;
            obj.params.maxIt = 10000;
            obj.params.checkError = 100;
            obj.params.theta = 1;
            obj.params.verbose = 0;
            obj.params.showPrimals = 0;
            obj.params.tryCPP = 0;
            obj.params.relativePathToMEX = 'flexBox_CPP/';
            obj.duals = {};
            obj.x = {};
            obj.xOld = {};
            obj.xBar = {};
            obj.y = {};
            obj.yOld = {};
            obj.dims = {};
            obj.numberPvars = 0;
            obj.DcP = {};
            obj.DcD = {};
            obj.xError = {};
            obj.yError = {};
            obj.firstRun = true;
        end
        function number = addPrimalVar(obj,dims)
            %number = addPrimalVar(dims)
            %adds a zero-initialized primal variable of dimensions #dims to
            %FlexBox and returns its internal #number. A scalar #dims n is
            %treated as an n-by-1 column. Internally the variable is stored
            %as a column vector with prod(dims) entries.
            if isscalar(dims)
                dims = [dims,1];
            end
            numberElements = prod(dims);
            obj.x{end+1} = zeros(numberElements,1);
            obj.xOld{end+1} = zeros(numberElements,1);
            obj.xBar{end+1} = zeros(numberElements,1);
            obj.dims{end+1} = dims;
            obj.numberPvars = obj.numberPvars + 1;
            number = obj.numberPvars;
        end
        function numberDuals = addTerm(obj,term,corresponding)
            %numberDuals = addTerm(term,corresponding)
            %adds a functional term to FlexBox. The variable #corresponding
            %specifies the internal number(s) of the corresponding primal
            %variables; it must contain exactly term.numPrimals entries.
            %The term's term.numVars dual variables are created as zero
            %vectors of lengths term.length{i}. The output #numberDuals
            %contains the internal number(s) of the created dual variables.
            if numel(corresponding) ~= term.numPrimals
                error(['Number of corresponding primals wrong. Expecting ',num2str(term.numPrimals),' variables'])
            end
            %this is a dual part
            obj.duals{end+1} = term;
            %save corresponding Dual->Primal, Dual->Dual
            firstDual = numel(obj.y) + 1;
            obj.DcP{end+1} = corresponding;
            obj.DcD{end+1} = firstDual : numel(obj.y) + term.numVars;
            numberDuals = obj.DcD{end};
            %create dual variables (current and old value)
            for i=1:term.numVars
                obj.y{end+1} = zeros(term.length{i},1);
                obj.yOld{end+1} = zeros(term.length{i},1);
            end
        end
        function showPrimal(obj,number)
            %showPrimal(number)
            %displays the primal variable with internal number #number as a
            %grayscale image in figure(number), using the fixed intensity
            %range [0,1]. 2D variables are shown as a whole; for 3D
            %variables only the first slice (the first dims(1)*dims(2)
            %entries) is shown. Other dimensionalities are not displayed.
            dimsP = obj.dims{number};
            if numel(dimsP) == 2
                figure(number);clf;
                imagesc(reshape(obj.x{number},dimsP),[0,1]);
                axis image;colormap(gray)
            elseif numel(dimsP) == 3
                sliceDims = dimsP(1:2);
                pixelPerSlice = prod(sliceDims);
                showSlice = 1;
                indexStart = (showSlice-1)*pixelPerSlice + 1;
                indexEnd = showSlice*pixelPerSlice;
                figure(number);clf;
                imagesc(reshape(obj.x{number}(indexStart:indexEnd),sliceDims),[0,1]);
                axis image;colormap(gray)
            end
        end
        function showPrimals(obj)
            %showPrimals
            %displays all primal variables (if 2D or 3D) and flushes the
            %graphics queue
            for i=1:numel(obj.x)
                obj.showPrimal(i);
            end
            drawnow;
        end
        function u = getPrimal(obj,number)
            %u = getPrimal(number)
            %returns the primal variable specified by #number, reshaped to
            %the dimensions given to addPrimalVar
            u = reshape(obj.x{number},obj.dims{number});
        end
        function y = getDual(obj,number)
            %y = getDual(number)
            %returns the dual variable specified by #number as a vector
            y = obj.y{number};
        end
        function runAlgorithm(obj,varargin)
            %runAlgorithm(varargin)
            %executes FlexBox and resets the internal iteration counter.
            %The execution can be terminated at any time without losing
            %the current state.
            %Optional arguments are turned into local variables by the
            %external script vararginParser (presumably name/value pairs -
            %confirm there). If it defines noInit == 1, init() is skipped
            %and the current step sizes tau/sigma are reused.
            vararginParser;
            skipInit = exist('noInit','var') && noInit == 1;
            if ~skipInit
                %initialize tau and sigma
                obj.init();
            end

            %check if C++ module is activated and compiled
            if obj.checkCPP()
                obj.doCPP();
                return;
            end

            reverseStr = [];
            if obj.firstRun
                %nothing computed yet: force at least one iteration
                obj.error = Inf;
                obj.firstRun = false;
            else
                %resume: measure the residual of the state left by the previous run
                obj.error = obj.calculateError();
            end

            %number of decimals used when printing the residual: enough to
            %resolve params.tol. ceil() keeps the precision an integer (a
            %raw -log10(tol), e.g. 3.301 for tol = 5e-4, would yield the
            %invalid specifier '%.3.301f'); the tiny offset absorbs log10
            %round-off for exact powers of ten; the result is clamped to 0..16.
            if obj.params.tol > 0
                residualDigits = min(16, max(0, ceil(-log10(obj.params.tol) - 1e-9)));
            else
                residualDigits = 16;
            end
            formatStr = ['Iteration: #%d : Residual %.', num2str(residualDigits), 'f', '\n'];

            iteration = 1;
            while obj.error > obj.params.tol && iteration <= obj.params.maxIt
                obj.doIteration;
                if mod(iteration,obj.params.checkError) == 0
                    obj.error = obj.calculateError();
                    reverseStr = printToCmd( reverseStr, sprintf(formatStr, iteration, obj.error) );
                end
                %show the primals at iterations 1, 1+k, 1+2k, ... with
                %k = params.showPrimals (mod(iteration,k) == 1 would never
                %be true for k = 1)
                if obj.params.showPrimals > 0 && mod(iteration - 1,obj.params.showPrimals) == 0
                    obj.showPrimals;
                end
                iteration = iteration + 1;
            end
            printToCmd( reverseStr,'');
        end
    end
    methods (Access=protected,Hidden=true )
        %protected methods that can only be accessed from class or
        %subclasses. These methods are hidden!
        function doCPP(obj)
            %doCPP
            %runs the compiled MEX function flexBoxCPP on this object and
            %copies its outputs back: the first numel(obj.x) outputs are the
            %primal variables, the remaining numel(obj.y) outputs are the
            %dual variables.
            numPrimal = numel(obj.x);
            numDual = numel(obj.y);
            resultCPP = cell(1, numPrimal + numDual);
            %call by name via feval (no code string is built and evaluated)
            [resultCPP{:}] = feval('flexBoxCPP', obj);
            for i=1:numPrimal
                obj.x{i} = resultCPP{i};
            end
            for i=1:numDual
                obj.y{i} = resultCPP{numPrimal+i};
            end
            obj.firstRun = false;
        end
        function doIteration(obj)
            %doIteration
            %performs one primal-dual iteration:
            % 1) store the current x and y as xOld and yOld,
            % 2) let every term compute yTilde and apply its prox to its
            %    dual variables (methods of the term objects),
            % 3) primal update x = x - tau .* (K^T y), summed over all terms,
            % 4) overrelaxation xBar = x + theta*(x - xOld).
            obj.xOld = obj.x;
            obj.yOld = obj.y;
            %calc yTilde and prox
            for i=1:numel(obj.duals)
                %inputs are the numbers of the dual variables the term
                %updates and of the primal variables it corresponds to
                obj.duals{i}.yTilde(obj,obj.DcD{i},obj.DcP{i});
                obj.duals{i}.applyProx(obj,obj.DcD{i},obj.DcP{i});
            end
            %primal update is x = x - tau .* K^t y
            for k=1:numel(obj.duals)
                dualNumbers = obj.DcD{k};
                primalNumbers = obj.DcP{k};
                numPrimals = numel(primalNumbers);
                for i=1:numel(dualNumbers)
                    for j=1:numPrimals
                        %operatorT is stored with the primal index running fastest
                        operatorNumber = numPrimals*(i-1) + j;
                        p = primalNumbers(j);
                        obj.x{p} = obj.x{p} - obj.params.tau{p}.*(obj.duals{k}.operatorT{operatorNumber} * obj.y{dualNumbers(i)});
                    end
                end
            end
            %do overrelaxation
            for i=1:numel(obj.x)
                obj.xBar{i} = obj.x{i} + obj.params.theta*(obj.x{i} - obj.xOld{i});
            end
        end
        %         function adaptStepsize(obj)
        %             [~,p,d] = obj.calculateError;
        %
        %             %if primal residual is massively larger than dual
        %             if ( p > obj.params.s*d*obj.params.delta )
        %                 for i=1:numel(obj.x)
        %                     obj.params.tau{i} = obj.params.tau{i} / (1-obj.params.adaptivity);%increase primal steplength
        %                 end
        %                 for i=1:numel(obj.y)
        %                     obj.params.sigma{i} = obj.params.sigma{i} * (1-obj.params.adaptivity);%decrease dual steplength
        %                 end
        %                 obj.params.adaptivity = obj.params.adaptivity * obj.params.eta;%decrease level of adaptivity
        %             %if dual residual is massively larger than primal
        %             elseif (p < obj.params.s*d/obj.params.delta)
        %                 for i=1:numel(obj.x)
        %                     obj.params.tau{i} = obj.params.tau{i} * (1-obj.params.adaptivity);%decrease primal steplength
        %                 end
        %                 for i=1:numel(obj.y)
        %                     obj.params.sigma{i} = obj.params.sigma{i} / (1-obj.params.adaptivity);%increase dual steplength
        %                 end
        %                 obj.params.adaptivity = obj.params.adaptivity * obj.params.eta;%decrease level of adaptivity
        %             end
        %
        % %             p
        % %             d
        % %             obj.params.tau
        % %             obj.params.sigma
        %         end
        function init(obj)
            %init
            %computes the diagonal step sizes params.tau (one vector per
            %primal variable) and params.sigma (one vector per dual
            %variable). Every term is initialized and its myTau/mySigma
            %contributions are summed per variable. The step size is the
            %reciprocal of each sum, with sums floored at 0.0001 to avoid
            %division by zero (e.g. for a primal variable no term touches).
            for i=1:numel(obj.x)
                obj.params.tau{i} = zeros(numel(obj.x{i}),1);
            end
            for i=1:numel(obj.y)
                obj.params.sigma{i} = zeros(numel(obj.y{i}),1);
            end
            %init duals and accumulate their contributions
            for i=1:numel(obj.duals)
                obj.duals{i}.init();
                %sum up tau
                for j=1:numel(obj.DcP{i})
                    indexTmp = obj.DcP{i}(j);
                    obj.params.tau{ indexTmp } = obj.params.tau{ indexTmp } + obj.duals{i}.myTau{j};
                end
                %sum up sigma
                for j=1:numel(obj.DcD{i})
                    indexTmp = obj.DcD{i}(j);
                    obj.params.sigma{ indexTmp } = obj.params.sigma{ indexTmp } + obj.duals{i}.mySigma{j};
                end
            end
            %calculate reciprocals
            for i=1:numel(obj.x)
                obj.params.tau{i} = 1 ./ max(0.0001,obj.params.tau{i});
            end
            for i=1:numel(obj.y)
                obj.params.sigma{i} = 1 ./ max(0.0001,obj.params.sigma{i});
            end
        end
        function [res,resP,resD] = calculateError(obj)
            %[res,resP,resD] = calculateError
            %calculates the residual of the primal dual algorithm.
            %xError/yError start as (new - old) ./ stepsize and are then
            %completed by each term's xError/yError methods.
            %resP / resD are the L1 norms of the primal / dual errors averaged
            %over the number of primal / dual variables, and res is their sum
            %divided by the total number of primal elements. Divisors are
            %clamped to at least 1, so an empty problem yields 0 instead of NaN.
            %calculate first part
            for i=1:numel(obj.x)
                obj.xError{i} = (obj.x{i} - obj.xOld{i}) ./ obj.params.tau{i};
            end
            for i=1:numel(obj.y)
                obj.yError{i} = (obj.y{i} - obj.yOld{i}) ./ obj.params.sigma{i};
            end
            %calculate second part
            for i=1:numel(obj.duals)
                obj.duals{i}.xError(obj,obj.DcD{i},obj.DcP{i});
                obj.duals{i}.yError(obj,obj.DcD{i},obj.DcP{i});
            end
            %sum up
            resP = 0;
            for i=1:numel(obj.x)
                resP = resP + sum(abs(obj.xError{i}));
            end
            resP = resP / max(1, numel(obj.x));
            resD = 0;
            for i=1:numel(obj.y)
                resD = resD + sum(abs(obj.yError{i}));
            end
            resD = resD / max(1, numel(obj.y));
            elements = 0;
            for i=1:numel(obj.dims)
                elements = elements + prod(obj.dims{i});
            end
            res = (resD + resP) / max(1, elements);
        end
        function result = checkCPP(obj)
            %result = checkCPP
            %returns 1 if params.tryCPP is set and the C++ MEX module
            %flexBoxCPP can be found, 0 otherwise. The folder given by
            %params.relativePathToMEX (relative to this file) is added to
            %the path so that the intended MEX file is called. If that
            %folder does not exist, a warning is printed and the MEX file
            %is still searched for on the MATLAB path.
            result = 0;
            if isempty(obj.params.tryCPP) || ~obj.params.tryCPP
                return;
            end
            absPathToMEX = fullfile(fileparts(mfilename('fullpath')), obj.params.relativePathToMEX);
            if exist(absPathToMEX, 'dir') ~= 7
                disp('Warning: relative Path to MEX-File is not correct! The default path is stored in params.relativePathToMEX');
            else
                %make sure the intended MEX file is called
                addpath(absPathToMEX);
            end
            %exist(...,'file') returns 3 for a MEX file
            if exist('flexBoxCPP','file') ~= 3
                disp('Warning: C++ module is not compiled!');
                disp('Running in MATLAB mode');
                return;
            end
            result = 1;
            if obj.params.verbose > 0
                disp(['using MEX-File: ', which('flexBoxCPP')]);
                disp('Running in C++ mode');
            end
        end
    end
end