diff --git a/inpainting_model.py b/inpainting_model.py
index d06049162..a840ca097 100644
--- a/inpainting_model.py
+++ b/inpainting_model.py
@@ -133,79 +133,6 @@ def forward(self, x) :
         x = x * self.alpha
         return x + skip
 
-# from https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
-# Contextual attention implementation is borrowed from IJCAI 2019 : "MUSICAL: Multi-Scale Image Contextual Attention Learning for Inpainting".
-# Original implementation causes bad results for Pytorch 1.2+.
-class GlobalLocalAttention(nn.Module):
-    def __init__(self, in_dim, patch_size=3, propagate_size=3, stride=1):
-        super(GlobalLocalAttention, self).__init__()
-        self.patch_size = patch_size
-        self.propagate_size = propagate_size
-        self.stride = stride
-        self.prop_kernels = None
-        self.in_dim = in_dim
-        self.feature_attention = GlobalAttention(in_dim)
-        self.patch_attention = GlobalAttentionPatch(in_dim)
-
-    def forward(self, foreground, mask, background="same"):
-        ###assume the masked area has value 1
-        bz, nc, w, h = foreground.size()
-        if background == "same":
-            background = foreground.clone()
-        mask = F.interpolate(mask, size=(w, h), mode='nearest')
-        background = background * (1 - mask)
-        foreground = self.feature_attention(foreground, background, mask)
-        background = F.pad(background,
-                           [self.patch_size // 2, self.patch_size // 2, self.patch_size // 2, self.patch_size // 2])
-        conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size,
-                                                                                     self.stride).contiguous().view(bz,
-                                                                                                                    nc,
-                                                                                                                    -1,
-                                                                                                                    self.patch_size,
-                                                                                                                    self.patch_size)
-
-        mask_resized = mask.repeat(1, self.in_dim, 1, 1)
-        mask_resized = F.pad(mask_resized,
-                             [self.patch_size // 2, self.patch_size // 2, self.patch_size // 2, self.patch_size // 2])
-        mask_kernels_all = mask_resized.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size,
-                                                                                       self.stride).contiguous().view(
-            bz,
-            nc,
-            -1,
-            self.patch_size,
-            self.patch_size)
-        conv_kernels_all = conv_kernels_all.transpose(2, 1)
-        mask_kernels_all = mask_kernels_all.transpose(2, 1)
-        output_tensor = []
-        for i in range(bz):
-            feature_map = foreground[i:i + 1]
-
-            # form convolutional kernels
-            conv_kernels = conv_kernels_all[i] + 0.0000001
-            mask_kernels = mask_kernels_all[i]
-            conv_kernels = self.patch_attention(conv_kernels, conv_kernels, mask_kernels)
-            norm_factor = torch.sum(conv_kernels ** 2, [1, 2, 3], keepdim=True) ** 0.5
-            conv_kernels = conv_kernels / norm_factor
-
-            conv_result = F.conv2d(feature_map, conv_kernels, padding=self.patch_size // 2)
-            if self.propagate_size != 1:
-                if self.prop_kernels is None:
-                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size], device = mask.device)
-                    self.prop_kernels.requires_grad = False
-                    self.prop_kernels = self.prop_kernels
-                conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))
-            mm = (torch.mean(mask_kernels_all[i], dim=[1,2,3], keepdim=True)==0.0).to(torch.float32)
-            mm = mm.permute(1,0,2,3)
-            conv_result = conv_result * mm
-            attention_scores = F.softmax(conv_result, dim=1)
-            attention_scores = attention_scores * mm
-
-            ##propagate the scores
-            recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride=1,
-                                                      padding=self.patch_size // 2)
-            output_tensor.append(recovered_foreground)
-        return torch.cat(output_tensor, dim=0)
-
 # from https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting
 class GlobalAttention(nn.Module):
     """ Self attention Layer"""
@@ -222,9 +149,8 @@ def __init__(self, in_dim):
         self.gamma = nn.parameter.Parameter(torch.tensor([1.0], requires_grad=True), requires_grad=True)
 
     def forward(self, a, b, c):
-        m_batchsize, C, width, height = a.size()  # B, C, H, W
-        down_rate = int(c.size(2)//width)
-        c = F.interpolate(c, scale_factor=1./down_rate*self.rate, mode='nearest')
+        m_batchsize, C, height, width = a.size()  # B, C, H, W
+        c = F.interpolate(c, size=(height, width), mode='nearest')
         proj_query = self.query_conv(a).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B, C, N -> B N C
         proj_key = self.key_conv(b).view(m_batchsize, -1, width * height)  # B, C, N
         feature_similarity = torch.bmm(proj_query, proj_key)  # B, N, N
@@ -236,47 +162,10 @@ def forward(self, a, b, c):
         attention = self.softmax(feature_pruning)  # B, N, C
         feature_pruning = torch.bmm(self.value_conv(a).view(m_batchsize, -1, width * height),
                                     attention.permute(0, 2, 1))  # -. B, C, N
-        out = feature_pruning.view(m_batchsize, C, width, height)  # B, C, H, W
+        out = feature_pruning.view(m_batchsize, C, height, width)  # B, C, H, W
         out = a * c + self.gamma * (1.0 - c) * out
         return out
-
-class GlobalAttentionPatch(nn.Module):
-    """ Self attention Layer"""
-
-    def __init__(self, in_dim):
-        super(GlobalAttentionPatch, self).__init__()
-        self.chanel_in = in_dim
-
-        self.query_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-        self.key_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-        self.value_channel = ScaledWSConv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
-
-        self.softmax_channel = nn.Softmax(dim=-1)
-        self.gamma = nn.parameter.Parameter(torch.tensor([1.0], requires_grad=True), requires_grad=True)
-
-    def forward(self, x, y, m):
-        '''
-        Something
-        '''
-        feature_size = list(x.size())
-        # Channel attention
-        query_channel = self.query_channel(x).view(feature_size[0], -1, feature_size[2] * feature_size[3])
-        key_channel = self.key_channel(y).view(feature_size[0], -1, feature_size[2] * feature_size[3]).permute(0,
-                                                                                                               2,
-                                                                                                               1)
-        channel_correlation = torch.bmm(query_channel, key_channel)
-        m_r = m.view(feature_size[0], -1, feature_size[2]*feature_size[3])
-        channel_correlation = torch.bmm(channel_correlation, m_r)
-        energy_channel = self.softmax_channel(channel_correlation)
-        value_channel = self.value_channel(x).view(feature_size[0], -1, feature_size[2] * feature_size[3])
-        attented_channel = (energy_channel * value_channel).view(feature_size[0], feature_size[1],
-                                                                 feature_size[2],
-                                                                 feature_size[3])
-        out = x * m + self.gamma * (1.0 - m) * attented_channel
-        return out
-
 
 class CoarseGenerator(nn.Module) :
     def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.2) :
         super(CoarseGenerator, self).__init__()
@@ -307,9 +196,15 @@ def __init__(self, in_ch = 4, out_ch = 3, ch = 32, alpha = 0.2) :
         self.body_conv = nn.Sequential(*self.body_conv)
 
         self.tail = nn.Sequential(
+            LambdaLayer(relu_nf),
+            GatedWSConvPadded(ch * 8, ch * 8, 3, 1),
             LambdaLayer(relu_nf),
             GatedWSConvPadded(ch * 8, ch * 4, 3, 1),
             LambdaLayer(relu_nf),
+            GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
+            LambdaLayer(relu_nf),
+            GatedWSConvPadded(ch * 4, ch * 4, 3, 1),
+            LambdaLayer(relu_nf),
             GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
             LambdaLayer(relu_nf),
             GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
@@ -347,89 +242,11 @@ def forward(self, img, mask) :
         x = self.tail(torch.cat([conv, attn], dim = 1))
         return torch.clip(x, -1, 1)
 
-class RefineGenerator(nn.Module) :
-    def __init__(self, in_ch = 5, out_ch = 3, ch = 64, alpha = 0.2) :
-        super(RefineGenerator, self).__init__()
-
-        self.head = nn.Sequential(
-            GatedWSConvPadded(in_ch, ch, 3, stride = 1),
-            LambdaLayer(relu_nf),
-            GatedWSConvPadded(ch, ch * 2, 4, stride = 2),
-            LambdaLayer(relu_nf),
-            GatedWSConvPadded(ch * 2, ch * 4, 4, stride = 2),
-        )
-
-        self.beta = 1.0
-        self.alpha = alpha
-        self.body_conv = []
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 2))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 4))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 8))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta, 16))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv.append(ResBlock(ch * 4, self.alpha, self.beta))
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_conv = nn.Sequential(*self.body_conv)
-
-        self.tail = nn.Sequential(
-            LambdaLayer(relu_nf),
-            GatedWSConvPadded(ch * 8, ch * 4, 3, 1),
-            LambdaLayer(relu_nf),
-            GatedWSTransposeConvPadded(ch * 4, ch * 2, 4, 2),
-            LambdaLayer(relu_nf),
-            GatedWSTransposeConvPadded(ch * 2, ch, 4, 2),
-            LambdaLayer(relu_nf),
-            GatedWSConvPadded(ch, out_ch, 3, stride = 1),
-        )
-
-        self.beta = 1.0
-
-        self.body_attn_1 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_2 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_3 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_4 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_attn = GlobalLocalAttention(in_dim = ch * 4)
-        self.body_attn_5 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_6 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-        self.body_attn_7 = ResBlock(ch * 4, self.alpha, self.beta)
-        self.beta = (self.beta ** 2 + self.alpha ** 2) ** 0.5
-
-    def forward(self, img, img_coarse, mask) :
-        x = img_coarse * mask + img * (1. - mask)
-        x = torch.cat([mask, x], dim = 1)
-        x = self.head(x)
-        attn = self.body_attn_1(x)
-        attn = self.body_attn_2(attn)
-        attn = self.body_attn_3(attn)
-        attn = self.body_attn_4(attn)
-        attn = self.body_attn_attn(attn, mask)
-        attn = self.body_attn_5(attn)
-        attn = self.body_attn_6(attn)
-        attn = self.body_attn_7(attn)
-        conv = self.body_conv(x)
-        x = self.tail(torch.cat([conv, attn], dim = 1))
-        return torch.clip(x, -1, 1)
-
 class InpaintingVanilla(nn.Module):
     def __init__(self):
         super(InpaintingVanilla, self).__init__()
         self.coarse_generator = CoarseGenerator(4, 3, 32)
-        self.fine_generator = RefineGenerator(4, 3, 64)
 
     def forward(self, x, mask):
         x_stage1 = self.coarse_generator(x, mask)
-        x_stage2 = self.fine_generator(x, x_stage1, mask)
-        return x_stage1, x_stage2
+        return x_stage1
diff --git a/text_mask_utils.py b/text_mask_utils.py
index 9406efe41..1f3da49b7 100755
--- a/text_mask_utils.py
+++ b/text_mask_utils.py
@@ -10,11 +10,6 @@
 from collections import defaultdict
 from scipy.optimize import linear_sum_assignment
 
-import argparse
-parser = argparse.ArgumentParser(description='Generate text bboxes given a image file')
-parser.add_argument('--image', default='', type=str, help='Image file')
-args = parser.parse_args()
-
 COLOR_RANGE_SIGMA = 1.5 # how many stddev away is considered the same color
 
 def save_rgb(fn, img) :
diff --git a/translate_demo.py b/translate_demo.py
index 62cf39144..f6c831005 100755
--- a/translate_demo.py
+++ b/translate_demo.py
@@ -27,6 +27,7 @@
 parser.add_argument('--link_threshold', default=0.4, type=float, help='link_threshold')
 parser.add_argument('--low_text', default=0.4, type=float, help='low_text')
 args = parser.parse_args()
+print(args)
 
 import unicodedata
 
@@ -419,6 +420,11 @@ def resize_keep_aspect(img, size) :
     return cv2.resize(img, (new_width, new_height), interpolation = cv2.INTER_LINEAR_EXACT)
 
 def run_inpainting(model_inpainting, img, mask, max_image_size = 1024, pad_size = 4) :
+    img_original = np.copy(img)
+    mask_original = np.copy(mask)
+    mask_original[mask_original < 127] = 0
+    mask_original[mask_original >= 127] = 1
+    mask_original = mask_original[:, :, None]
     if not args.use_inpainting :
         img = np.copy(img)
         img[mask > 0] = np.array([255, 255, 255], np.uint8)
@@ -439,17 +445,21 @@
     if new_h != h or new_w != w :
         img = cv2.resize(img, (new_w, new_h), interpolation = cv2.INTER_LINEAR_EXACT)
         mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR_EXACT)
+    print(f'Inpainting resolution: {new_w}x{new_h}')
     img_torch = torch.from_numpy(img).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0
     mask_torch = torch.from_numpy(mask).unsqueeze_(0).unsqueeze_(0).float() / 255.0
+    mask_torch[mask_torch < 0.5] = 0
+    mask_torch[mask_torch >= 0.5] = 1
     if args.use_cuda :
         img_torch = img_torch.cuda()
         mask_torch = mask_torch.cuda()
     with torch.no_grad() :
-        _, img_inpainted_torch = model_inpainting(img_torch, mask_torch)
+        img_torch *= (1 - mask_torch)
+        img_inpainted_torch = model_inpainting(img_torch, mask_torch)
     img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
     if new_h != height or new_w != width :
         img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR_EXACT)
-    return img_inpainted
+    return img_inpainted * mask_original + img_original * (1 - mask_original)
 
 from baidutrans import Translator as baidu_trans
 baidu_translator = baidu_trans()
@@ -504,7 +514,7 @@ def main() :
     img_to_overlay = np.copy(img_resized)
     ratio_h = ratio_w = 1 / target_ratio
     img_resized = imgproc.normalizeMeanVariance(img_resized)
-    print(img_resized.shape)
+    print(f'Detection resolution: {img_resized.shape[1]}x{img_resized.shape[0]}')
     print(' -- Running text detection')
     rscore, ascore, mask = run_detect(model_detect, img_resized)
     overlay = imgproc.cvt2HeatmapImg(rscore + ascore)
@@ -579,7 +589,7 @@ def main() :
     final_mask = cv2.resize(final_mask, (img.shape[1], img.shape[0]), interpolation = cv2.INTER_LINEAR)
     print(' -- Running inpainting')
     # run inpainting
-    img_inpainted = run_inpainting(model_inpainting, img, final_mask)
+    img_inpainted = run_inpainting(model_inpainting, img, final_mask, args.inpainting_size)
     print(' -- Translating')
     # translate text region texts
     texts = '\n'.join([r.text for r in text_regions])
@@ -618,7 +628,7 @@ def main() :
     cv2.imwrite('result/bbox_unfiltered.png', img_bbox_all)
     cv2.imwrite('result/overlay.png', cv2.cvtColor(overlay_image(img_to_overlay, cv2.resize(overlay, (img_resized.shape[1], img_resized.shape[0]), interpolation=cv2.INTER_LINEAR)), cv2.COLOR_RGB2BGR))
     cv2.imwrite('result/mask.png', final_mask)
-    cv2.imwrite('result/masked.png', cv2.cvtColor(img_inpainted, cv2.COLOR_RGB2BGR))
+    cv2.imwrite('result/inpainted.png', cv2.cvtColor(img_inpainted, cv2.COLOR_RGB2BGR))
     cv2.imwrite('result/final.png', cv2.cvtColor(img_canvas, cv2.COLOR_RGB2BGR))
 
 if __name__ == '__main__':
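
Note on the new inference path: InpaintingVanilla now returns only the coarse result, and run_inpainting binarizes the mask, zeroes out the masked region before the forward pass, and composites the output back over the untouched pixels. A minimal end-to-end sketch of that flow (the helper name inpaint_once is hypothetical; it assumes the same [-1, 1] normalization and masked-area == 1 convention as run_inpainting above):

    import numpy as np
    import torch

    def inpaint_once(model, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
        # img: HxWx3 uint8 RGB; mask: HxW uint8, 255 where text should be removed
        img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
        mask_t = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).float() / 255.0
        mask_t = (mask_t >= 0.5).float()   # binarize: masked area == 1
        img_t = img_t * (1 - mask_t)       # zero out the region to be filled
        with torch.no_grad():
            out = model(img_t, mask_t)     # single-stage: coarse generator only
        out_np = ((out.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1.0) * 127.5).astype(np.uint8)
        keep = (mask >= 127)[:, :, None]   # composite inpainted pixels over the original
        return np.where(keep, out_np, img)

The final np.where keeps original pixels wherever the binarized mask is zero, mirroring the img_inpainted * mask_original + img_original * (1 - mask_original) composite added in the diff.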