forked from rtqichen/style-swap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstyle-swap.lua
237 lines (186 loc) · 6.89 KB
/
style-swap.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
require 'torch'
lapp = require 'pl.lapp'
-- Command-line interface, parsed with Penlight's lapp.
opt = lapp[[
== Basic options ==
--style (default '') File path to target image for style
--content (default '') File path to target image for content
--contentBatch (default '') Directory path to target images for content
== More Options ==
--maxContentSize (default 640) Maximum height and width for content image
--maxStyleSize (default 512) Maximum height and width for style image
--save (default output) Directory to save in
--saveOriginal If set, saves the original image as well
== Advanced ==
--gpu (default 0)
--patchSize (default 3) Patch size for style swap [Higher = More Style Texture]
--patchStride (default 1) Patch stride for style swap operation
--pooling (default 'avg') One of [avg|max]
--numSwap (default 1) Number of times to perform the style swap operation [Higher = More Style Contrast]
--decoder (default '') Path to a trained decoder
--optim If set, decoder is only used for initialization and optimization still occurs
--learningRate (default 0.05) Learning rate for optimization
--init (default 'content') How to initialize the generated image [random|content]
--tv (default 1e-7) Weight for TV loss [Higher = Blur]
--layer (default 'relu3_1') VGG layer to style swap on
--optimIter (default 100) Number of iterations for optimization
--printEvery (default 50) Print loss every so iterations
--saveLoss If set, saves a table of loss values.
]]
print(opt)

-- Raise `msg` (reported at the caller's line) unless `ok` is truthy.
local function check(ok, msg)
  if not ok then
    error(msg, 2)
  end
end

-- Validate the inputs in order; the first failing requirement aborts the run.
check(opt.style ~= '', '--style must be provided.')
check(opt.content ~= '' or opt.contentBatch ~= '',
  '--content or --contentBatch must be provided.')
check(paths.filep(opt.style),
  'Style image ' .. opt.style .. ' does not exist.')
check(opt.content == '' or paths.filep(opt.content),
  'Content image ' .. opt.content .. ' does not exist.')
check(opt.contentBatch == '' or paths.dirp(opt.contentBatch),
  'Content directory ' .. opt.contentBatch .. ' does not exist.')
check(opt.decoder == '' or paths.filep(opt.decoder),
  'Decoder ' .. opt.decoder .. ' does not exist.')
-- Load the heavy (CUDA/model) dependencies only after the argument checks above,
-- so an invalid command line fails before these modules are loaded.
print('Loading Lua modules...')
require 'nn'
require 'cudnn'
require 'cunn' -- NOTE(review): presumably brings in cutorch (used for cutorch.setDevice); confirm
require 'loadcaffe'
require 'lib/ArtisticStyleLossCriterion'
require 'image'
require 'lib/ImageLoader'
require 'lib/NonparametricPatchAutoencoderFactory'
require 'lib/InstanceNormalization'
require 'optim'
require 'lib/MaxCoord'
-- NOTE(review): paths.* is already called by the argument checks above, so it is
-- presumably made available by require 'torch'; this line is then a cached no-op.
require 'paths'
-- Duplicate of the require 'image' above; harmless because require caches modules.
require 'image'
-- Select the GPU: --gpu is 0-based while cutorch.setDevice is 1-based.
cutorch.setDevice(opt.gpu + 1)

-- Pretrained VGG-19 (Caffe weights) used as the fixed feature extractor.
local vgg = loadcaffe.load('models/VGG_ILSVRC_19_layers_deploy.prototxt',
  'models/VGG_ILSVRC_19_layers.caffemodel', 'nn')
-- Remove modules 46 down to 37 (back to front so earlier indices stay valid).
-- NOTE(review): these indices presumably cover pool5 and the classifier head of
-- the loadcaffe VGG-19 graph; they are tied to this exact prototxt.
for i = 46, 37, -1 do
  vgg:remove(i)
end

-- Loss configuration: a content term at --layer (weight 1) plus a TV term
-- weighted by --tv. Kept local: nothing outside this setup reads them.
local layers = { content = { opt.layer } }
local weights = { content = 1, tv = opt.tv }
local use_avg_pooling = opt.pooling == 'avg'

-- No targets exist at construction time (this previously read an undefined
-- global `targets`, which was always nil); per-image targets are installed
-- later via setTarget on the last module of criterion.net.
criterion = nn.ArtisticStyleLossCriterion(vgg, layers, use_avg_pooling, weights, nil, false)

-- Only criterion.net is used from here on; drop our handle to the raw model.
vgg = nil
collectgarbage()
print(criterion.net)
-- Optional learned decoder. The loaded network is wrapped so callers can pass
-- and receive a single feature map without a batch dimension.
if opt.decoder ~= '' then
  dec = torch.load(opt.decoder)
  decoder = nn.Sequential()
    :add(nn.Unsqueeze(1)) -- add batch dim
    :add(dec)
    :add(nn.Squeeze(1)) -- remove batch dim
  decoder:cuda()
  collectgarbage()
  print(dec)
end
-- Loss history: one entry is appended per stylized output (a loss tensor from
-- synth, or a single decoder loss), and the table is written to
-- <save>/loss.t7 when --saveLoss is set.
-- (The unused orig_window/optim_window/decoder_window locals were removed.)
local optim_losses = {}
--- Optimize an image with Adam to minimize the current criterion.
-- The input is cloned, so `img` itself is not modified. Before every loss
-- evaluation the current estimate is clamped to [0, 1] in place. The
-- per-iteration losses are appended to `optim_losses` as one tensor.
-- @param img starting image tensor (content crop, random noise, or decoder output)
-- @return the optimized image tensor
function synth(img)
  local x = img:clone()
  local adam_state = { learningRate = opt.learningRate }
  local losses = torch.Tensor(opt.optimIter)
  local iter = 0

  -- Objective for optim.adam: loss and flattened gradient at `params`.
  -- Defined once as a local, instead of redefining a global every iteration.
  local function feval(params)
    -- Keep pixel values in [0, 1]; clamp modifies the optimized tensor in place.
    params:clamp(0, 1)
    local loss = criterion:forward(params)
    local loss_grad = criterion:backward(params)
    losses[iter] = loss
    if iter % opt.printEvery == 0 then
      print(string.format('%d, %e', iter, loss))
    end
    return loss, loss_grad:view(-1)
  end

  for i = 1, opt.optimIter do
    iter = i
    optim.adam(feval, x, adam_state)
  end
  optim_losses[#optim_losses + 1] = losses
  print('Done')
  return x
end
-- Encode the style image once and build the patch-swap network from its features.
-- Load with depth 3 so grayscale or RGBA files still reach VGG as 3 channels.
local style_img = image.load(opt.style, 3)
-- NOTE(review): unlike the content image, the style image is always rescaled so
-- its longest side equals --maxStyleSize, even when that enlarges it.
style_img = image.scale(style_img, opt.maxStyleSize)
style_img = style_img:cuda()

criterion.targets = true -- override behavior
-- NOTE(review): presumably lets criterion.net run as a plain feature extractor
-- without per-image targets; confirm against ArtisticStyleLossCriterion.
criterion.net:forward(style_img)
local style_latent = criterion.net.output:clone()

-- Swap network: patch encoder -> MaxCoord -> patch decoder, all built from the
-- style features.
local swap_enc, swap_dec = NonparametricPatchAutoencoderFactory.buildAutoencoder(
  style_latent, opt.patchSize, opt.patchStride, false, false, true)
swap = nn.Sequential()
swap:add(swap_enc)
swap:add(nn.MaxCoord())
swap:add(swap_dec)
swap:evaluate()
swap:cuda()
print(swap)
--- Output file name for a stylized image: "<stem>_stylized.<ext>".
-- Only the final extension is replaced, using a single anchored match. The
-- previous gsub('.' .. ext, ...) treated '.' as a wildcard, replaced every
-- occurrence, and crashed when the name had no extension.
local function stylizedName(name)
  local stem, ext = name:match('^(.+)%.([^./]+)$')
  if stem == nil then
    return name .. '_stylized'
  end
  return stem .. '_stylized.' .. ext
end

--- Stylize one content image and save it as <save>/<stem>_stylized.<ext>.
-- The image is encoded with criterion.net, and the features are passed through
-- the `swap` network. The swapped features are then turned back into an image:
-- either by the learned decoder (refined with synth when --optim is set), or by
-- optimizing directly from the content crop or random noise (--init).
-- @param img 3-D image tensor (height and width in dims 2 and 3)
-- @param name file name used for the saved original and the output
-- @return the stylized image tensor
function swapTransfer(img, name)
  if opt.saveOriginal then
    image.save(opt.save .. '/' .. name, img)
  end
  img = img:cuda()

  -- Features of the input image at the swap layer.
  criterion:unsetTargets()
  criterion.net:forward(img)
  local img_latent = criterion.net.output:clone()
  criterion.net:clearState()

  -- Features after swapping against the style patches.
  local swap_latent = swap:forward(img_latent):clone()
  swap:clearState()

  -- The swapped features become the target of the last module of criterion.net.
  local target_module = criterion.net.modules[#criterion.net.modules]
  local x
  if opt.decoder ~= '' then
    x = decoder:forward(swap_latent):clone()
    decoder:clearState()
    target_module:setTarget(swap_latent)
    if opt.optim then
      x = synth(x)
    else
      -- Use the decoder output as-is; still record its loss for --saveLoss.
      local dec_loss = criterion:forward(x)
      optim_losses[#optim_losses + 1] = torch.Tensor{dec_loss}
    end
  else
    -- Optimize in image space. Crop the input to the spatial size matching
    -- swap_latent, assuming each VGG block before --layer halves the resolution.
    local block = tonumber(string.match(opt.layer, "(%d)_%d"))
    if block == nil then
      error('Cannot parse VGG block number from --layer ' .. tostring(opt.layer))
    end
    local scale = 2 ^ (block - 1)
    local H, W = swap_latent:size(2) * scale, swap_latent:size(3) * scale
    img = image.crop(img:double(), 0, 0, W, H):cuda()
    if opt.init == 'random' then img:uniform() end
    target_module:setTarget(swap_latent)
    x = synth(img)
  end

  image.save(opt.save .. '/' .. stylizedName(name), x)
  criterion.net:clearState()
  return x
end
print('Creating save folder at ' .. opt.save)
paths.mkdir(opt.save)

-- Apply the style swap --numSwap times, feeding each result back in as the
-- next input.
-- NOTE(review): every pass reuses `name`. Later passes therefore overwrite the
-- previous stylized file, and with --saveOriginal they also save their
-- (already stylized) input under the original's name.
local function runSwaps(img, name)
  for _ = 1, opt.numSwap do
    img = swapTransfer(img, name)
  end
  return img
end

if opt.content ~= '' then
  -- Single image: force 3 channels and shrink (never enlarge) to --maxContentSize.
  local img = image.load(opt.content, 3)
  local H, W = img:size(2), img:size(3)
  if H > opt.maxContentSize or W > opt.maxContentSize then
    img = image.scale(img, opt.maxContentSize)
  end
  runSwaps(img, paths.basename(opt.content))
else
  -- Directory mode; size limiting is delegated to ImageLoader:setMaximumSize.
  local imageLoader = ImageLoader(opt.contentBatch)
  imageLoader:setMaximumSize(opt.maxContentSize)
  for _ = 1, #imageLoader.files do
    local img, name = imageLoader:next()
    runSwaps(img, name)
  end
end

if opt.saveLoss then
  torch.save(opt.save .. '/loss.t7', optim_losses)
end