update inpainting models
zyddnys committed Mar 4, 2021
1 parent fa47a4e commit 9b0eb78
Showing 3 changed files with 25 additions and 203 deletions.
inpainting_model.py (203 changes: 10 additions & 193 deletions)
@@ -133,79 +133,6 @@ def forward(self, x) :
x = x * self.alpha
return x + skip

# from https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
# The contextual attention implementation is borrowed from the IJCAI 2019 paper
# "MUSICAL: Multi-Scale Image Contextual Attention Learning for Inpainting".
# The original implementation produces bad results on PyTorch 1.2+.
class GlobalLocalAttention(nn.Module):
def __init__(self, in_dim, patch_size=3, propagate_size=3, stride=1):
super(GlobalLocalAttention, self).__init__()
self.patch_size = patch_size
self.propagate_size = propagate_size
self.stride = stride
self.prop_kernels = None
self.in_dim = in_dim
self.feature_attention = GlobalAttention(in_dim)
self.patch_attention = GlobalAttentionPatch(in_dim)

def forward(self, foreground, mask, background="same"):
# assume the masked area has value 1
bz, nc, w, h = foreground.size()
if background == "same":
background = foreground.clone()
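# Resize the mask to the feature resolution and blank out the holes, so that
# background patches are sampled only from valid (unmasked) pixels.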
mask = F.interpolate(mask, size=(w, h), mode='nearest')
background = background * (1 - mask)
foreground = self.feature_attention(foreground, background, mask)
background = F.pad(background, [self.patch_size // 2] * 4)
conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride).contiguous().view(bz, nc, -1, self.patch_size, self.patch_size)

mask_resized = mask.repeat(1, self.in_dim, 1, 1)
mask_resized = F.pad(mask_resized, [self.patch_size // 2] * 4)
mask_kernels_all = mask_resized.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride).contiguous().view(bz, nc, -1, self.patch_size, self.patch_size)
conv_kernels_all = conv_kernels_all.transpose(2, 1)
mask_kernels_all = mask_kernels_all.transpose(2, 1)
output_tensor = []
for i in range(bz):
feature_map = foreground[i:i + 1]

# form convolutional kernels
conv_kernels = conv_kernels_all[i] + 1e-7
mask_kernels = mask_kernels_all[i]
conv_kernels = self.patch_attention(conv_kernels, conv_kernels, mask_kernels)
norm_factor = torch.sum(conv_kernels ** 2, [1, 2, 3], keepdim=True) ** 0.5
conv_kernels = conv_kernels / norm_factor

conv_result = F.conv2d(feature_map, conv_kernels, padding=self.patch_size // 2)
if self.propagate_size != 1:
if self.prop_kernels is None:
self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size], device = mask.device)
self.prop_kernels.requires_grad = False
conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))
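# mm flags kernels whose mask patch is entirely zero (pure background);
# attention scores against partially masked patches are zeroed below.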
mm = (torch.mean(mask_kernels_all[i], dim=[1,2,3], keepdim=True)==0.0).to(torch.float32)
mm = mm.permute(1,0,2,3)
conv_result = conv_result * mm
attention_scores = F.softmax(conv_result, dim=1)
attention_scores = attention_scores * mm

# propagate the scores
recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride=1, padding=self.patch_size // 2)
output_tensor.append(recovered_foreground)
return torch.cat(output_tensor, dim=0)
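
# A minimal shape check for this module (a sketch; sizes are hypothetical and
# the ScaledWSConv2d / GlobalAttention definitions in this file are assumed in scope):
#
#   attn = GlobalLocalAttention(in_dim = 128)
#   feats = torch.randn(2, 128, 32, 32)
#   holes = (torch.rand(2, 1, 32, 32) > 0.9).float()  # 1 marks masked pixels
#   out = attn(feats, holes)                          # background = "same"
#   assert out.shape == (2, 128, 32, 32)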

# from https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
class GlobalAttention(nn.Module):
""" Self attention Layer"""
@@ -222,9 +149,8 @@ def __init__(self, in_dim):
self.gamma = nn.Parameter(torch.tensor([1.0]))

def forward(self, a, b, c):
m_batchsize, C, width, height = a.size() # B, C, H, W
down_rate = int(c.size(2)//width)
c = F.interpolate(c, scale_factor=1./down_rate*self.rate, mode='nearest')
m_batchsize, C, height, width = a.size() # B, C, H, W
c = F.interpolate(c, size=(height, width), mode='nearest')
proj_query = self.query_conv(a).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B, C, N -> B N C
proj_key = self.key_conv(b).view(m_batchsize, -1, width * height) # B, C, N
feature_similarity = torch.bmm(proj_query, proj_key) # B, N, N
@@ -236,47 +162,10 @@ def forward(self, a, b, c):
attention = self.softmax(feature_pruning) # B, N, C
feature_pruning = torch.bmm(self.value_conv(a).view(m_batchsize, -1, width * height), attention.permute(0, 2, 1))  # -> B, C, N
out = feature_pruning.view(m_batchsize, C, width, height) # B, C, H, W
out = feature_pruning.view(m_batchsize, C, height, width) # B, C, H, W
out = a * c + self.gamma * (1.0 - c) * out
return out
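
# Note on the change above: the removed lines derived a single integer ratio
# from one spatial dim (and unpacked H/W in the wrong order), which breaks for
# non-square inputs and non-integral ratios; resizing the mask straight to the
# query's spatial size is robust. A sketch:
#
#   a = torch.randn(1, 32, 48, 64)            # B, C, H, W, non-square
#   c = torch.rand(1, 1, 96, 128)             # mask at 2x resolution
#   c = F.interpolate(c, size = a.shape[-2:], mode = 'nearest')
#   assert c.shape[-2:] == (48, 64)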


class GlobalAttentionPatch(nn.Module):
""" Self attention Layer"""

def __init__(self, in_dim):
super(GlobalAttentionPatch, self).__init__()
self.channel_in = in_dim

self.query_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.key_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.value_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

self.softmax_channel = nn.Softmax(dim=-1)
self.gamma = nn.Parameter(torch.tensor([1.0]))

def forward(self, x, y, m):
'''
Channel-wise attention between patch stacks x and y, gated by mask m.
'''
feature_size = list(x.size())
# Channel attention
query_channel = self.query_channel(x).view(feature_size[0], -1, feature_size[2] * feature_size[3])
key_channel = self.key_channel(y).view(feature_size[0], -1, feature_size[2] * feature_size[3]).permute(0, 2, 1)
channel_correlation = torch.bmm(query_channel, key_channel)
m_r = m.view(feature_size[0], -1, feature_size[2]*feature_size[3])
channel_correlation = torch.bmm(channel_correlation, m_r)
energy_channel = self.softmax_channel(channel_correlation)
value_channel = self.value_channel(x).view(feature_size[0], -1, feature_size[2] * feature_size[3])
attended_channel = (energy_channel * value_channel).view(feature_size[0], feature_size[1], feature_size[2], feature_size[3])
out = x * m + self.gamma * (1.0 - m) * attended_channel
return out


class CoarseGenerator(nn.Module) :
def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.2) :
super(CoarseGenerator, self).__init__()
@@ -307,9 +196,15 @@ def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.2) :
self.body_conv = nn.Sequential(*self.body_conv)

self.tail = nn.Sequential(
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 8, ch * 8, 3, 1),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 8, ch * 4, 3, 1),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
LambdaLayer(relu_nf),
GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
LambdaLayer(relu_nf),
GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
@@ -347,89 +242,11 @@ def forward(self, img, mask) :
x = self.tail(torch.cat([conv, attn], dim = 1))
return torch.clip(x, -1, 1)

class RefineGenerator(nn.Module) :
def __init__(self, in_ch = 5, out_ch = 3, ch = 64, alpha = 0.2) :
super(RefineGenerator, self).__init__()

self.head = nn.Sequential(
GatedWSConvPadded(in_ch, ch, 3, stride = 1),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch, ch * 2, 4, stride = 2),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 2, ch * 4, 4, stride = 2),
)

self.beta = 1.0
self.alpha = alpha
self.body_conv = []
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 2))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 4))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 8))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 16))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_conv = nn.Sequential(*self.body_conv)
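# Variance bookkeeping, apparently in the normalizer-free-net style: each
# ResBlock adds a branch scaled by alpha onto a skip with std ~beta, so the
# expected std grows to sqrt(beta**2 + alpha**2) per block; after the seven
# blocks above, beta = sqrt(1 + 7 * 0.2**2) ~= 1.131.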

self.tail = nn.Sequential(
LambdaLayer(relu_nf),
GatedWSConvPadded(ch * 8, ch * 4, 3, 1),
LambdaLayer(relu_nf),
GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
LambdaLayer(relu_nf),
GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
LambdaLayer(relu_nf),
GatedWSConvPadded(ch, out_ch, 3, stride = 1),
)

self.beta = 1.0

self.body_attn_1 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_2 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_3 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_4 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_attn = GlobalLocalAttention(in_dim = ch * 4)
self.body_attn_5 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_6 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
self.body_attn_7 = ResBlock(ch * 4, self.alpha, self.beta)
self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5

def forward(self, img, img_coarse, mask) :
x = img_coarse * mask + img * (1. - mask)
x = torch.cat([mask, x], dim = 1)
x = self.head(x)
attn = self.body_attn_1(x)
attn = self.body_attn_2(attn)
attn = self.body_attn_3(attn)
attn = self.body_attn_4(attn)
attn = self.body_attn_attn(attn, mask)
attn = self.body_attn_5(attn)
attn = self.body_attn_6(attn)
attn = self.body_attn_7(attn)
conv = self.body_conv(x)
x = self.tail(torch.cat([conv, attn], dim = 1))
return torch.clip(x, -1, 1)

class InpaintingVanilla(nn.Module):
def __init__(self):
super(InpaintingVanilla, self).__init__()
self.coarse_generator = CoarseGenerator(4, 3, 32)
self.fine_generator = RefineGenerator(4, 3, 64)

def forward(self, x, mask):
x_stage1 = self.coarse_generator(x, mask)
x_stage2 = self.fine_generator(x, x_stage1, mask)
return x_stage1, x_stage2
return x_stage1
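
# A quick smoke test of the single-stage model as it stands after this commit
# (a sketch; sizes are hypothetical, inputs live in [-1, 1], and holes are
# pre-zeroed the way run_inpainting in translate_demo.py now does):
#
#   model = InpaintingVanilla().eval()
#   img = torch.rand(1, 3, 256, 256) * 2 - 1
#   mask = (torch.rand(1, 1, 256, 256) > 0.9).float()  # 1 = hole
#   with torch.no_grad() :
#       out = model(img * (1 - mask), mask)
#   assert out.shape == (1, 3, 256, 256)
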
text_mask_utils.py (5 changes: 0 additions & 5 deletions)
@@ -10,11 +10,6 @@
from collections import defaultdict
from scipy.optimize import linear_sum_assignment

import argparse
parser = argparse.ArgumentParser(description='Generate text bboxes given a image file')
parser.add_argument('--image', default='', type=str, help='Image file')
args = parser.parse_args()

COLOR_RANGE_SIGMA = 1.5 # how many stddev away is considered the same color

def save_rgb(fn, img) :
translate_demo.py (20 changes: 15 additions & 5 deletions)
@@ -27,6 +27,7 @@
parser.add_argument('--link_threshold', default=0.4, type=float, help='link_threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='low_text')
args = parser.parse_args()
print(args)

import unicodedata

@@ -419,6 +420,11 @@ def resize_keep_aspect(img, size) :
return cv2.resize(img, (new_width, new_height), interpolation = cv2.INTER_LINEAR_EXACT)

def run_inpainting(model_inpainting, img, mask, max_image_size = 1024, pad_size = 4) :
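# Keep untouched copies and a binary single-channel mask up front; the final
# composite at the end of this function restores every pixel outside the mask
# from the original image.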
img_original = np.copy(img)
mask_original = np.copy(mask)
mask_original[mask_original < 127] = 0
mask_original[mask_original >= 127] = 1
mask_original = mask_original[:, :, None]
if not args.use_inpainting :
img = np.copy(img)
img[mask > 0] = np.array([255, 255, 255], np.uint8)
@@ -439,17 +445,21 @@ def run_inpainting(model_inpainting, img, mask, max_image_size = 1024, pad_size = 4) :
if new_h != h or new_w != w :
img = cv2.resize(img, (new_w, new_h), interpolation = cv2.INTER_LINEAR_EXACT)
mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR_EXACT)
print(f'Inpainting resolution: {new_w}x{new_h}')
img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
mask_torch[mask_torch < 0.5] = 0
mask_torch[mask_torch >= 0.5] = 1
if args.use_cuda :
img_torch = img_torch.cuda()
mask_torch = mask_torch.cuda()
with torch.no_grad() :
_, img_inpainted_torch = model_inpainting(img_torch, mask_torch)
img_torch *= (1 - mask_torch)
img_inpainted_torch = model_inpainting(img_torch, mask_torch)
img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
if new_h != height or new_w != width :
img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR_EXACT)
return img_inpainted
return img_inpainted * mask_original + img_original * (1 - mask_original)
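
# The composite above replaces only masked pixels; a toy check (sketch):
#
#   img = np.full((2, 2, 3), 100, np.uint8)
#   inp = np.full((2, 2, 3), 200, np.uint8)
#   m = np.array([[0, 1], [0, 0]], np.uint8)[:, :, None]  # 1 = inpainted
#   out = inp * m + img * (1 - m)
#   assert (out[0, 1] == 200).all() and (out[0, 0] == 100).all()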

from baidutrans import Translator as baidu_trans
baidu_translator = baidu_trans()
@@ -504,7 +514,7 @@ def main() :
img_to_overlay = np.copy(img_resized)
ratio_h = ratio_w = 1 / target_ratio
img_resized = imgproc.normalizeMeanVariance(img_resized)
print(img_resized.shape)
print(f'Detection resolution: {img_resized.shape[1]}x{img_resized.shape[0]}')
print(' -- Running text detection')
rscore, ascore, mask = run_detect(model_detect, img_resized)
overlay = imgproc.cvt2HeatmapImg(rscore + ascore)
@@ -579,7 +589,7 @@ def main() :
final_mask = cv2.resize(final_mask, (img.shape[1], img.shape[0]), interpolation = cv2.INTER_LINEAR)
print(' -- Running inpainting')
# run inpainting
img_inpainted = run_inpainting(model_inpainting, img, final_mask)
img_inpainted = run_inpainting(model_inpainting, img, final_mask, args.inpainting_size)
print(' -- Translating')
# translate text region texts
texts = '\n'.join([r.text for r in text_regions])
@@ -618,7 +628,7 @@ def main() :
cv2.imwrite('result/bbox_unfiltered.png', img_bbox_all)
cv2.imwrite('result/overlay.png', cv2.cvtColor(overlay_image(img_to_overlay, cv2.resize(overlay, (img_resized.shape[1], img_resized.shape[0]), interpolation=cv2.INTER_LINEAR)), cv2.COLOR_RGB2BGR))
cv2.imwrite('result/mask.png', final_mask)
cv2.imwrite('result/masked.png', cv2.cvtColor(img_inpainted, cv2.COLOR_RGB2BGR))
cv2.imwrite('result/inpainted.png', cv2.cvtColor(img_inpainted, cv2.COLOR_RGB2BGR))
cv2.imwrite('result/final.png', cv2.cvtColor(img_canvas, cv2.COLOR_RGB2BGR))

if __name__ == '__main__':
