-
Notifications
You must be signed in to change notification settings - Fork 6
/
partialconv_unet.py
151 lines (126 loc) · 6.41 KB
/
partialconv_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable
class PartialConv2d(nn.Conv2d):
    """2-D partial convolution layer built on ``nn.Conv2d``.

    The convolution only sees input pixels marked valid by a mask. Its bias-free
    response is rescaled by ``slide_winsize / sum(mask in window)`` so that
    partially covered windows are not pulled towards zero. An updated mask is
    produced for the next layer: ``conv(mask, ones)`` clamped to ``[0, 1]``.
    For a binary mask, that is 1 wherever the window held at least one valid
    pixel.

    Accepts every ``nn.Conv2d`` argument plus two extra keyword arguments:
        multi_channel (bool, default False): if True the mask has one channel
            per input channel and the mask window spans all input channels;
            if False the mask is single-channel.
        return_mask (bool, default False): if True ``forward`` returns
            ``(output, update_mask)``, otherwise only ``output``.

    The mask statistics of the last call are cached. Calling ``forward``
    without a mask on an input of the same shape reuses them.
    """

    def __init__(self, *args, **kwargs):
        # Pop the extra options before nn.Conv2d sees (and rejects) them.
        multi_channel = kwargs.pop('multi_channel', False)
        return_mask = kwargs.pop('return_mask', False)
        super().__init__(*args, **kwargs)
        self.multi_channel = multi_channel
        self.return_mask = return_mask

        # All-ones kernel: convolving the mask with it counts valid entries per window.
        # It is a plain attribute (not a registered buffer), so it is not part of
        # state_dict; forward() moves it to the input's dtype/device on demand.
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(
                self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        # Number of mask entries covered by one window (C_in*kh*kw or kh*kw).
        _, mask_channels, kernel_h, kernel_w = self.weight_maskUpdater.shape
        self.slide_winsize = mask_channels * kernel_h * kernel_w

        # Cache of the mask statistics from the most recent call.
        self.last_size = None    # full input shape the cache was computed for
        self.update_mask = None  # propagated mask, clamped to [0, 1]
        self.mask_ratio = None   # per-location rescale factor, 0 where no valid entry

    def forward(self, input, mask=None):
        """Apply the partial convolution to ``input``.

        Args:
            input: 4-D tensor ``(N, C_in, H, W)``.
            mask: optional validity mask (1 = valid, 0 = hole) with one channel
                when ``multi_channel`` is False, ``C_in`` channels when True.
                It is cast to ``input``'s dtype and device. If None, an all-ones
                mask is used, unless statistics cached for an input of the same
                shape can be reused.

        Returns:
            ``output``, or ``(output, update_mask)`` when ``return_mask`` is True.
        """
        mask_given = mask is not None
        if mask_given:
            mask = mask.to(input)

        if mask_given or self.last_size != tuple(input.shape):
            # Key the cache on the full shape: in multi_channel mode the mask
            # carries the batch dimension, so a new batch size must invalidate it.
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                # No-op when dtype and device (including device index) already match.
                self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if not mask_given:
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = input.new_ones(1, 1, input.shape[2], input.shape[3])
                update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None,
                                       stride=self.stride, padding=self.padding,
                                       dilation=self.dilation, groups=1)
                ratio = self.slide_winsize / (update_mask + 1e-8)
                update_mask = torch.clamp(update_mask, 0, 1)
                # The 1e-8 guard underflows to 0 in float16, making `ratio` inf where
                # no entry is valid; torch.where keeps those locations at 0 instead of
                # inf * 0 = NaN. All other values equal ratio * update_mask.
                self.mask_ratio = torch.where(update_mask > 0, ratio * update_mask,
                                              torch.zeros_like(ratio))
                self.update_mask = update_mask

        # Cached statistics may be on a stale dtype/device (e.g. after model.to()).
        self.update_mask = self.update_mask.to(input)
        self.mask_ratio = self.mask_ratio.to(input)

        # Zero the holes so their contents cannot leak into the output (W^T (X * M)).
        raw_out = super().forward(input * mask if mask_given else input)

        if self.bias is not None:
            # Rescale only the data term, then re-add the bias; fully invalid
            # locations are zeroed by update_mask.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output
class PartialConvBlock(nn.Module):
    """PartialConv2d -> ELU -> optional InstanceNorm2d.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias:
            forwarded to ``PartialConv2d``.
        return_mask: forwarded to ``PartialConv2d``. When True (default),
            ``forward`` returns ``(features, updated_mask)``; when False it
            returns only the features.
        normalize: if truthy, apply ``nn.InstanceNorm2d(out_channels)`` after
            the ELU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding=0, dilation=1, bias=True, return_mask=True, normalize=True):
        super().__init__()
        self.conv = PartialConv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  dilation=dilation, bias=bias, return_mask=return_mask)
        self.elu = nn.ELU(inplace=True)
        # Holds the norm layer when enabled, otherwise the original falsy flag;
        # the attribute name is kept for compatibility with existing code.
        self.normalize = nn.InstanceNorm2d(out_channels) if normalize else normalize
        self.return_mask = return_mask

    def forward(self, feat, fmask):
        """Run the block on features ``feat`` with validity mask ``fmask``."""
        if self.return_mask:
            feat, fmask = self.conv(feat, fmask)
        else:
            # The conv returns a bare tensor in this mode; tuple-unpacking it
            # would split the batch dimension (or fail for batch size 1).
            feat = self.conv(feat, fmask)
        feat = self.elu(feat)
        if self.normalize:
            feat = self.normalize(feat)
        return (feat, fmask) if self.return_mask else feat
class PartialConvUnet(nn.Module):
    """Encoder / dilated bottleneck / decoder inpainting network of PartialConvBlocks.

    Features are downsampled by the two stride-2 blocks (enc_1, enc_4) and
    brought back by two nearest-neighbour resizes in ``forward``. Despite the
    name, ``forward`` uses no encoder-to-decoder skip connections.
    """

    def __init__(self):
        super().__init__()
        # encoder blocks
        self.enc_0 = PartialConvBlock(3, 64, 3, 1, 1, normalize=False)
        self.enc_1 = PartialConvBlock(64, 128, 3, 2, 1)
        self.enc_2 = PartialConvBlock(128, 128, 3, 1, 1)
        self.enc_3 = PartialConvBlock(128, 128, 3, 1, 1)
        self.enc_4 = PartialConvBlock(128, 256, 3, 2, 1)
        # dilation blocks
        self.dil_0 = PartialConvBlock(256, 256, 3, 1, 2, dilation=2)
        self.dil_1 = PartialConvBlock(256, 256, 3, 1, 2, dilation=2)
        self.dil_2 = PartialConvBlock(256, 256, 3, 1, 2, dilation=2)
        self.dil_3 = PartialConvBlock(256, 256, 3, 1, 2, dilation=2)
        # decoder blocks
        self.dec_5 = PartialConvBlock(256, 256, 3, 1, 1)
        self.dec_4 = PartialConvBlock(256, 256, 3, 1, 1)
        self.dec_3 = PartialConvBlock(256, 128, 3, 1, 1)
        self.dec_2 = PartialConvBlock(128, 128, 3, 1, 1)
        self.dec_1 = PartialConvBlock(128, 64, 3, 1, 1)
        self.dec_0 = PartialConvBlock(64, 32, 3, 1, 1, normalize=False)
        self.post_dec = nn.Sequential(nn.Conv2d(32, 3, 1, 1, 0),
                                      nn.Tanh())

    @staticmethod
    def _upsample(feat, fmask, size):
        """Nearest-neighbour resize of features and mask to spatial ``size``."""
        return (F.interpolate(feat, size=size, mode='nearest'),
                F.interpolate(fmask, size=size, mode='nearest'))

    def forward(self, occl_img, mask):
        """Inpaint ``occl_img``.

        Args:
            occl_img: image batch, presumably ``(N, 3, H, W)`` (enc_0 takes 3 channels).
            mask: validity mask with 1 = known pixel, 0 = hole; it multiplies the
                image and is fed to enc_0 as its mask, so presumably a single-channel
                ``(N or 1, 1, H, W)`` tensor.

        Returns:
            ``(N, 3, H, W)`` tensor in ``[-1, 1]`` (Tanh output). The decoder
            resizes back to the sizes recorded in the encoder, so H and W need
            not be multiples of 4.
        """
        # remove pixel information in the masked area.
        feat, fmask = occl_img * mask, mask
        feat, fmask = self.enc_0(feat, fmask)
        full_size = tuple(feat.shape[-2:])  # resolution before the first stride-2 block
        feat, fmask = self.enc_1(feat, fmask)
        feat, fmask = self.enc_2(feat, fmask)
        feat, fmask = self.enc_3(feat, fmask)
        half_size = tuple(feat.shape[-2:])  # resolution before the second stride-2 block
        feat, fmask = self.enc_4(feat, fmask)
        feat, fmask = self.dil_0(feat, fmask)
        feat, fmask = self.dil_1(feat, fmask)
        feat, fmask = self.dil_2(feat, fmask)
        feat, fmask = self.dil_3(feat, fmask)
        feat, fmask = self.dec_5(feat, fmask)
        feat, fmask = self.dec_4(feat, fmask)
        # Resize to the recorded sizes rather than a fixed scale_factor=2 so that
        # odd intermediate sizes (e.g. 15 -> 8) map back exactly (8 -> 15).
        feat, fmask = self._upsample(feat, fmask, half_size)
        feat, fmask = self.dec_3(feat, fmask)
        feat, fmask = self.dec_2(feat, fmask)
        feat, fmask = self._upsample(feat, fmask, full_size)
        feat, fmask = self.dec_1(feat, fmask)
        feat, fmask = self.dec_0(feat, fmask)
        feat = self.post_dec(feat)
        return feat
if __name__ == "__main__":
    # Smoke test: run one random image through the network and print the shapes.
    img = torch.randn(1, 3, 128, 128)
    # Binary validity mask (1 = known pixel, 0 = hole) with a square hole.
    # A randn mask is not a valid mask: its negative, non-binary values corrupt
    # the valid-pixel counts used for partial-conv renormalisation.
    mask = torch.ones(1, 1, 128, 128)
    mask[:, :, 32:96, 32:96] = 0
    model = PartialConvUnet()
    model.eval()
    with torch.no_grad():
        out = model(img, mask)
    print('in', img.size())
    print('out', out.size())