From 1e4954e32ae0cefb5c86c16785bf5dbd4a890ba1 Mon Sep 17 00:00:00 2001 From: zyddnys Date: Thu, 6 May 2021 11:33:03 -0400 Subject: [PATCH] alpha-v2.2.1 --- DBHead.py | 73 ++++++++++++++++++ DBNet_resnet101.py | 152 ++++++++++++++++++++++++++++++++++++++ README.md | 62 +++------------- README_EN.md | 60 +++------------ dbnet_utils.py | 179 +++++++++++++++++++++++++++++++++++++++++++++ model_ocr.py | 10 +-- translate_demo.py | 64 ++++++++-------- 7 files changed, 468 insertions(+), 132 deletions(-) create mode 100644 DBHead.py create mode 100644 DBNet_resnet101.py create mode 100644 dbnet_utils.py diff --git a/DBHead.py b/DBHead.py new file mode 100644 index 000000000..d818e4dc3 --- /dev/null +++ b/DBHead.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# @Time : 2019/12/4 14:54 +# @Author : zhoujun +import torch +from torch import nn + +class DBHead(nn.Module): + def __init__(self, in_channels, out_channels, k = 50): + super().__init__() + self.k = k + self.binarize = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 4, 3, padding=1), + nn.BatchNorm2d(in_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 4, 2, 1), + nn.BatchNorm2d(in_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, 1, 4, 2, 1), + ) + self.binarize.apply(self.weights_init) + + self.thresh = self._init_thresh(in_channels) + self.thresh.apply(self.weights_init) + + def forward(self, x): + shrink_maps = self.binarize(x) + threshold_maps = self.thresh(x) + if self.training: + binary_maps = self.step_function(shrink_maps.sigmoid(), threshold_maps) + y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1) + else: + y = torch.cat((shrink_maps, threshold_maps), dim=1) + return y + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
+ m.bias.data.fill_(1e-4) + + def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False): + in_channels = inner_channels + if serial: + in_channels += 1 + self.thresh = nn.Sequential( + nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias), + nn.Sigmoid()) + return self.thresh + + def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False): + if smooth: + inter_out_channels = out_channels + if out_channels == 1: + inter_out_channels = in_channels + module_list = [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)] + if out_channels == 1: + module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True)) + return nn.Sequential(module_list) + else: + return nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1) + + def step_function(self, x, y): + return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) diff --git a/DBNet_resnet101.py b/DBNet_resnet101.py new file mode 100644 index 000000000..4696cd502 --- /dev/null +++ b/DBNet_resnet101.py @@ -0,0 +1,152 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchvision.models import resnet101 + +import DBHead +import einops + +class ImageMultiheadSelfAttention(nn.Module) : + def __init__(self, planes): + super(ImageMultiheadSelfAttention, self).__init__() + self.attn = nn.MultiheadAttention(planes, 8) + def forward(self, x) : + res = x + n, c, h, w = x.shape + x = einops.rearrange(x, 'n c h w -> (h w) n c') + x = self.attn(x, x, x)[0] + x = einops.rearrange(x, '(h w) n c -> n c h w', n = n, c = c, h = h, w = w) + return res + x + +class double_conv(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch, stride = 1, planes = 256): + super(double_conv, self).__init__() + self.planes = planes + # down = None + # if stride > 1 : + # down = nn.Sequential( + # nn.AvgPool2d(2, 2), + # nn.Conv2d(in_ch + mid_ch, self.planes * Bottleneck.expansion, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(self.planes * Bottleneck.expansion) + # ) + self.down = None + if stride > 1 : + self.down = nn.AvgPool2d(2,stride=2) + self.conv = nn.Sequential( + nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, stride = 1, bias=False), + nn.BatchNorm2d(mid_ch), + nn.ReLU(inplace=True), + #Bottleneck(mid_ch, self.planes, stride, down, 2, 1, avd = True, norm_layer = nn.BatchNorm2d), + nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride = 1, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + if self.down is not None : + x = self.down(x) + x = self.conv(x) + return x + +class double_conv_up(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch, stride = 1, planes = 256): + super(double_conv_up, self).__init__() + self.planes = planes + # down = None + # if stride > 1 : + # down = nn.Sequential( + # nn.AvgPool2d(2, 2), + # nn.Conv2d(in_ch + mid_ch, self.planes * Bottleneck.expansion, kernel_size=1, stride=1, bias=False),nn.BatchNorm2d(self.planes * Bottleneck.expansion) + # ) + self.down = None + if stride > 1 : + self.down = nn.AvgPool2d(2,stride=2) + self.conv = nn.Sequential( + nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=3, padding=1, 
stride = 1, bias=False), + nn.BatchNorm2d(mid_ch), + nn.ReLU(inplace=True), + #Bottleneck(mid_ch, self.planes, stride, down, 2, 1, avd = True, norm_layer = nn.BatchNorm2d), + nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride = 1, padding=1, bias=False), + nn.BatchNorm2d(mid_ch), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride = 2, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + if self.down is not None : + x = self.down(x) + x = self.conv(x) + return x + +class TextDetection(nn.Module) : + def __init__(self, pretrained=None) : + super(TextDetection, self).__init__() + self.backbone = resnet101(pretrained=True if pretrained else False) + + self.conv_db = DBHead.DBHead(64, 0) + + self.conv_mask = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(64, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(32, 1, kernel_size=1), + nn.Sigmoid() + ) + + self.down_conv1 = double_conv(0, 512, 512, 2) + self.down_conv2 = double_conv(0, 512, 512, 2) + self.down_conv3 = double_conv(0, 512, 512, 2) + + self.upconv1 = double_conv_up(0, 512, 256) + self.upconv2 = double_conv_up(256, 512, 256) + self.upconv3 = double_conv_up(256, 512, 256) + self.upconv4 = double_conv_up(256, 512, 256, planes = 128) + self.upconv5 = double_conv_up(256, 256, 128, planes = 64) + self.upconv6 = double_conv_up(128, 128, 64, planes = 32) + self.upconv7 = double_conv_up(64, 64, 64, planes = 16) + + self.proj_h4 = nn.Conv2d(64 * 4, 64, 1) + self.proj_h8 = nn.Conv2d(128 * 4, 128, 1) + self.proj_h16 = nn.Conv2d(256 * 4, 256, 1) + self.proj_h32 = nn.Conv2d(512 * 4, 512, 1) + + def forward(self, x) : + x = self.backbone.conv1(x) + x = self.backbone.bn1(x) + x = self.backbone.relu(x) + x = self.backbone.maxpool(x) # 64@384 + + h4 = self.backbone.layer1(x) # 64@384 + h8 = self.backbone.layer2(h4) # 128@192 + h16 = self.backbone.layer3(h8) # 256@96 + h32 = self.backbone.layer4(h16) # 512@48 + + h4 = self.proj_h4(h4) + h8 = self.proj_h8(h8) + h16 = self.proj_h16(h16) + h32 = self.proj_h32(h32) + + h64 = self.down_conv1(h32) # 512@24 + h128 = self.down_conv2(h64) # 512@12 + h256 = self.down_conv3(h128) # 512@6 + + up256 = self.upconv1(h256) # 128@12 + up128 = self.upconv2(torch.cat([up256, h128], dim = 1)) # 64@24 + up64 = self.upconv3(torch.cat([up128, h64], dim = 1)) # 128@48 + up32 = self.upconv4(torch.cat([up64, h32], dim = 1)) # 64@96 + up16 = self.upconv5(torch.cat([up32, h16], dim = 1)) # 128@192 + up8 = self.upconv6(torch.cat([up16, h8], dim = 1)) # 64@384 + up4 = self.upconv7(torch.cat([up8, h4], dim = 1)) # 64@768 + + return self.conv_db(up8), self.conv_mask(up4) + +if __name__ == '__main__' : + device = torch.device("cuda:0") + net = TextDetection().to(device) + img = torch.randn(2, 3, 1024, 1024).to(device) + db, seg = net(img) + print(db.shape) + print(seg.shape) diff --git a/README.md b/README.md index e516c561a..985e8c86c 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,17 @@ https://touhou.ai/imgtrans/ Note this may not work sometimes due to stupid google gcp kept restarting my instance. In that case you can wait for me to restart the service, which may take up to 24 hrs. # English README [README_EN.md](README_EN.md) -# 关于新模型 -新模型使用DBNet,正在训练,将更好支持英文识别。 \ -新的图片修复将去掉attention以减少显存占用。 \ -预计一到两周左右出来。 +# Changelogs +### 2021-05-06 +1. 检测模型更新为基于ResNet101的DBNet +2. OCR模型更新更深 +3. 
默认检测分辨率增加到2048 + +注意这个版本除了英文检测稍微好一些,其他方面都不如之前版本 +### 2021-03-04 +1. 添加图片修补模型 +### 2021-02-17 +1. 初步版本发布 # 一键翻译各类图片内文字 针对群内、各个图站上大量不太可能会有人去翻译的图片设计,让我这种日语小白能够勉强看懂图片\ 主要支持日语,不过也能识别汉语和小写英文 \ @@ -15,7 +22,7 @@ Note this may not work sometimes due to stupid google gcp kept restarting my ins # 使用说明 1. clone这个repo -2. [下载](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2)ocr.ckpt、detect.ckpt和inpainting.ckpt,放到这个repo的根目录下 +2. [下载](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2.1)ocr.ckpt、detect.ckpt和inpainting.ckpt,放到这个repo的根目录下 3. 申请百度翻译API,把你的appid和密钥存到key.py里 4. 运行`python translate_demo.py --image <图片文件路径> [--use-inpainting] [--use-cuda]`,结果会存放到result文件夹里。请加上`--use-inpainting`使用图像修补,请加上`--use-cuda`使用GPU。 # 只是初步版本,我们需要您的帮助完善 @@ -38,48 +45,3 @@ Note this may not work sometimes due to stupid google gcp kept restarting my ins ![Original](original2.jpg "https://twitter.com/mmd_96yuki/status/1320122899005460481")|![Output](result2.png) ![Original](original3.jpg "https://twitter.com/_taroshin_/status/1231099378779082754")|![Output](result3.png) ![Original](original4.jpg "https://amagi.fanbox.cc/posts/1904941")|![Output](result4.png) -# Citation -``` -@inproceedings{baek2019character, - title={Character region awareness for text detection}, - author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk}, - booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, - pages={9365--9374}, - year={2019} -} -@article{hinami2020towards, - title={Towards Fully Automated Manga Translation}, - author={Hinami, Ryota and Ishiwatari, Shonosuke and Yasuda, Kazuhiko and Matsui, Yusuke}, - journal={arXiv preprint arXiv:2012.14271}, - year={2020} -} -@article{oord2017neural, - title={Neural discrete representation learning}, - author={Oord, Aaron van den and Vinyals, Oriol and Kavukcuoglu, Koray}, - journal={arXiv preprint arXiv:1711.00937}, - year={2017} -} -@article{uddin2020global, - title={Global and Local Attention-Based Free-Form Image Inpainting}, - author={Uddin, SM and Jung, Yong Ju}, - journal={Sensors}, - volume={20}, - number={11}, - pages={3204}, - year={2020}, - publisher={Multidisciplinary Digital Publishing Institute} -} -@article{brock2021characterizing, - title={Characterizing signal propagation to close the performance gap in unnormalized ResNets}, - author={Brock, Andrew and De, Soham and Smith, Samuel L}, - journal={arXiv preprint arXiv:2101.08692}, - year={2021} -} -@inproceedings{fujimoto2016manga109, - title={Manga109 dataset and creation of metadata}, - author={Fujimoto, Azuma and Ogawa, Toru and Yamamoto, Kazuyoshi and Matsui, Yusuke and Yamasaki, Toshihiko and Aizawa, Kiyoharu}, - booktitle={Proceedings of the 1st international workshop on comics analysis, processing and understanding}, - pages={1--5}, - year={2016} -} -``` \ No newline at end of file diff --git a/README_EN.md b/README_EN.md index b320eeb3a..ece5e1403 100644 --- a/README_EN.md +++ b/README_EN.md @@ -1,8 +1,17 @@ # Online Demo https://touhou.ai/imgtrans/ Note this may not work sometimes due to stupid google gcp kept restarting my instance. In that case you can wait for me to restart the service, which may take up to 24 hrs. -# New model delayed -New model delayed due to poor result, I am fixing it however there is no guarantee it will be out this week. +# Changelogs +### 2021-05-06 +1. Text detection model is now based on DBNet with ResNet101 backbone +2. OCR model is now deeper +3. 
Default detection resolution has been increased to 2048 from 1536 + +Note this version is slightly better at handling English texts, other than that it is worse in every other ways +### 2021-03-04 +1. Added inpainting model +### 2021-02-17 +1. First version launched # Translate texts in manga/images Some manga/images will never be translated, therefore this project is born, \ Primarily designed for translating Japanese text, but also support Chinese and sometimes English \ @@ -11,7 +20,7 @@ Successor to https://github.com/PatchyVideo/MMDOCR-HighPerformance # How to use 1. Clone this repo -2. [Download](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2)ocr.ckpt、detect.ckpt and inpainting.ckpt,put them in the root directory of this repo +2. [Download](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2.1)ocr.ckpt、detect.ckpt and inpainting.ckpt,put them in the root directory of this repo 3. Apply for baidu translate API, put ypur appid and key in `key.py` 4. Run`python translate_demo.py --image [--use-inpainting] [--use-cuda]`,result can be found in `result/`. Add `--use-inpainting` to enable inpainting, Add `--use-cuda` to use CUDA. # This is a hobby project, you are welcome to contribute @@ -32,48 +41,3 @@ Original | Translated ![Original](original2.jpg "https://twitter.com/mmd_96yuki/status/1320122899005460481")|![Output](result2.png) ![Original](original3.jpg "https://twitter.com/_taroshin_/status/1231099378779082754")|![Output](result3.png) ![Original](original4.jpg "https://amagi.fanbox.cc/posts/1904941")|![Output](result4.png) -# Citation -``` -@inproceedings{baek2019character, - title={Character region awareness for text detection}, - author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk}, - booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, - pages={9365--9374}, - year={2019} -} -@article{hinami2020towards, - title={Towards Fully Automated Manga Translation}, - author={Hinami, Ryota and Ishiwatari, Shonosuke and Yasuda, Kazuhiko and Matsui, Yusuke}, - journal={arXiv preprint arXiv:2012.14271}, - year={2020} -} -@article{oord2017neural, - title={Neural discrete representation learning}, - author={Oord, Aaron van den and Vinyals, Oriol and Kavukcuoglu, Koray}, - journal={arXiv preprint arXiv:1711.00937}, - year={2017} -} -@article{uddin2020global, - title={Global and Local Attention-Based Free-Form Image Inpainting}, - author={Uddin, SM and Jung, Yong Ju}, - journal={Sensors}, - volume={20}, - number={11}, - pages={3204}, - year={2020}, - publisher={Multidisciplinary Digital Publishing Institute} -} -@article{brock2021characterizing, - title={Characterizing signal propagation to close the performance gap in unnormalized ResNets}, - author={Brock, Andrew and De, Soham and Smith, Samuel L}, - journal={arXiv preprint arXiv:2101.08692}, - year={2021} -} -@inproceedings{fujimoto2016manga109, - title={Manga109 dataset and creation of metadata}, - author={Fujimoto, Azuma and Ogawa, Toru and Yamamoto, Kazuyoshi and Matsui, Yusuke and Yamasaki, Toshihiko and Aizawa, Kiyoharu}, - booktitle={Proceedings of the 1st international workshop on comics analysis, processing and understanding}, - pages={1--5}, - year={2016} -} -``` \ No newline at end of file diff --git a/dbnet_utils.py b/dbnet_utils.py new file mode 100644 index 000000000..69853c7b7 --- /dev/null +++ b/dbnet_utils.py @@ -0,0 +1,179 @@ + +import pyclipper +import cv2 +import numpy as np +from shapely.geometry 
import Polygon + +class SegDetectorRepresenter(): + def __init__(self, thresh=0.7, box_thresh=0.8, max_candidates=1000, unclip_ratio=1.8): + self.min_size = 3 + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + + def __call__(self, batch, pred, is_output_polygon=False): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. + pred: + binary: text region segmentation map, with shape (N, H, W) + thresh: [if exists] thresh hold prediction with shape (N, H, W) + thresh_binary: [if exists] binarized with threshhold, (N, H, W) + ''' + pred = pred[:, 0, :, :] + segmentation = self.binarize(pred) + boxes_batch = [] + scores_batch = [] + for batch_index in range(pred.size(0)): + height, width = batch['shape'][batch_index] + if is_output_polygon: + boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height) + else: + boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height) + boxes_batch.append(boxes) + scores_batch.append(scores) + return boxes_batch, scores_batch + + def binarize(self, pred): + return pred > self.thresh + + def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap.cpu().numpy() # The first channel + pred = pred.cpu().detach().numpy() + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:self.max_candidates]: + epsilon = 0.005 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + # _, sside = self.get_mini_boxes(contour) + # if sside < self.min_size: + # continue + score = self.box_score_fast(pred, contour.squeeze(1)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, unclip_ratio=self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_size + 2: + continue + + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap.cpu().numpy() # The first channel + pred = pred.cpu().detach().numpy() + height, width = bitmap.shape + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + num_contours = min(len(contours), self.max_candidates) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + scores = 
np.zeros((num_contours,), dtype=np.float32) + + for index in range(num_contours): + contour = contours[index].squeeze(1) + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, contour) + if self.box_thresh > score: + continue + + box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + startidx = box.sum(axis=1).argmin() + box = np.roll(box, 4-startidx, 0) + box = np.array(box) + boxes[index, :, :] = box.astype(np.int16) + scores[index] = score + return boxes, scores + + def unclip(self, box, unclip_ratio=1.8): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] diff --git a/model_ocr.py b/model_ocr.py index 600fcf205..41f87c5f5 100755 --- a/model_ocr.py +++ b/model_ocr.py @@ -155,7 +155,7 @@ class ResNet_FeatureExtractor(nn.Module): def __init__(self, input_channel, output_channel=128): super(ResNet_FeatureExtractor, self).__init__() - self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [3, 4, 6, 4]) + self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [3, 5, 7, 5]) def forward(self, input): return self.ConvNet(input) @@ -299,10 +299,10 @@ def __init__(self, dictionary, max_len): self.dictionary = dictionary self.dict_size = len(dictionary) self.backbone = ResNet_FeatureExtractor(3, 512) - encoder = nn.TransformerEncoderLayer(512, 8, dropout = 0.0) - decoder = nn.TransformerDecoderLayer(512, 8, dropout = 0.0) - self.encoders = nn.TransformerEncoder(encoder, 3) - self.decoders = nn.TransformerDecoder(decoder, 3) + encoder = nn.TransformerEncoderLayer(512, 4, dropout = 0.0) + decoder = nn.TransformerDecoderLayer(512, 4, dropout = 0.0) + self.encoders = nn.TransformerEncoder(encoder, 2) + self.decoders = nn.TransformerDecoder(decoder, 2) self.pe = 
PositionalEncoding(512, max_len = max_len) self.embd = nn.Embedding(self.dict_size, 512) self.pred = nn.Sequential(nn.Dropout(0.1), nn.Linear(512, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.dict_size)) diff --git a/translate_demo.py b/translate_demo.py index c7fd0adfb..d6b5d4395 100755 --- a/translate_demo.py +++ b/translate_demo.py @@ -3,7 +3,7 @@ from typing import List from networkx.algorithms.distance_measures import center import torch -from CRAFT_resnet34 import CRAFT_net +from DBNet_resnet101 import TextDetection from model_ocr import OCR from inpainting_model import InpaintingVanilla import einops @@ -12,6 +12,7 @@ import cv2 import numpy as np import craft_utils +import dbnet_utils import itertools import networkx as nx import math @@ -19,7 +20,7 @@ parser = argparse.ArgumentParser(description='Generate text bboxes given a image file') parser.add_argument('--image', default='', type=str, help='Image file') -parser.add_argument('--size', default=1536, type=int, help='image square size') +parser.add_argument('--size', default=2048, type=int, help='image square size') parser.add_argument('--use-inpainting', action='store_true', help='turn on/off inpainting') parser.add_argument('--use-cuda', action='store_true', help='turn on/off cuda') parser.add_argument('--inpainting-size', default=768, type=int, help='size of image used for inpainting (too large will result in OOM)') @@ -31,7 +32,7 @@ import unicodedata -TEXT_EXT_RATIO = 0.1 +TEXT_EXT_RATIO = 0.01 class BBox(object) : def __init__(self, x: int, y: int, w: int, h: int, text: str, prob: float, fg_r: int = 0, fg_g: int = 0, fg_b: int = 0, bg_r: int = 0, bg_g: int = 0, bg_b: int = 0) : @@ -102,7 +103,7 @@ def bbox_direction(x1, y1, w1, h1, ratio = 2.2) : else : return 'v' -def can_merge_textline(x1, y1, w1, h1, x2, y2, w2, h2, ratio = 2.2, char_diff_ratio = 0.7, char_gap_tolerance = 1.3) : +def can_merge_textline(x1, y1, w1, h1, x2, y2, w2, h2, ratio = 2.2, char_diff_ratio = 0.5, char_gap_tolerance = 0.5) : char_size = min(h1, h2, w1, w2) if w1 > h1 * ratio and w2 > h2 * ratio : # both horizontal char_size = min(h1, h2) @@ -303,16 +304,16 @@ def merge_bboxes_text_region(bboxes: List[BBox]) : yield merged_box, nodes, majority_dir, fg_r, fg_g, fg_b, bg_r, bg_g, bg_b def run_detect(model, img_np_resized) : + img_np_resized = img_np_resized.astype(np.float32) / 127.5 - 1.0 img = torch.from_numpy(img_np_resized) if args.use_cuda : img = img.cuda() img = einops.rearrange(img, 'h w c -> 1 c h w') with torch.no_grad() : - craft, mask = model(img) - rscore = craft[0, 0, :, :].cpu().numpy() - ascore = craft[0, 1, :, :].cpu().numpy() + db, mask = model(img) + db = db.sigmoid().cpu() mask = mask[0, 0, :, :].cpu().numpy() - return rscore, ascore, (mask * 255.0).astype(np.uint8) + return db, (mask * 255.0).astype(np.uint8) def overlay_image(a, b, wa = 0.7) : return cv2.addWeighted(a, wa, b, 1 - wa, 0) @@ -389,7 +390,7 @@ def map_bbox(ubox) : images = einops.rearrange(images, 'N H W C -> N C H W') ret = ocr_infer_bacth(images, model, widths) for i, (pred_chars_index, prob, fr, fg, fb, br, bg, bb) in enumerate(ret) : - if prob < 0.2 : + if prob < 0.6 : continue fr = (torch.clip(fr.view(-1), 0, 1).mean() * 255).long().item() fg = (torch.clip(fg.view(-1), 0, 1).mean() * 255).long().item() @@ -477,9 +478,9 @@ def load_ocr_model() : return dictionary, model def load_detect_model() : - model = CRAFT_net() + model = TextDetection() sd = torch.load('detect.ckpt', map_location='cpu') - model.load_state_dict(sd['model']) + 
model.load_state_dict(sd['model'] if 'model' in sd else sd)
 	model.eval()
 	if args.use_cuda :
 		model = model.cuda()
@@ -513,16 +514,18 @@ def main() :
 	img_resized, target_ratio, _, pad_w, pad_h = imgproc.resize_aspect_ratio(img, args.size, cv2.INTER_LINEAR, mag_ratio = 1)
 	img_to_overlay = np.copy(img_resized)
 	ratio_h = ratio_w = 1 / target_ratio
-	img_resized = imgproc.normalizeMeanVariance(img_resized)
 	print(f'Detection resolution: {img_resized.shape[1]}x{img_resized.shape[0]}')
 	print(' -- Running text detection')
-	rscore, ascore, mask = run_detect(model_detect, img_resized)
-	overlay = imgproc.cvt2HeatmapImg(rscore + ascore)
-	boxes, polys = craft_utils.getDetBoxes(rscore, ascore, args.text_threshold, args.link_threshold, args.low_text, False)
-	boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h, ratio_net = 2)
-	polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2)
-	for k in range(len(polys)):
-		if polys[k] is None: polys[k] = boxes[k]
+	db, mask = run_detect(model_detect, img_resized)
+	overlay = imgproc.cvt2HeatmapImg(db[0, 0, :, :].numpy())
+	det = dbnet_utils.SegDetectorRepresenter()
+	boxes, scores = det({'shape':[(img_resized.shape[0], img_resized.shape[1])]}, db)
+	boxes, scores = boxes[0], scores[0]
+	idx = boxes.reshape(boxes.shape[0], -1).sum(axis=1) > 0
+	polys, _ = boxes[idx], scores[idx]
+	polys = polys.astype(np.float64)
+	polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 1)
+	polys = polys.astype(np.int16)
 	# merge textlines
 	polys = merge_bboxes(polys, can_merge_textline)
 	for [tl, tr, br, bl] in polys :
@@ -595,15 +598,18 @@ def main() :
 	# translate text region texts
 	texts = '\n'.join([r.text for r in text_regions])
 	trans_ret = baidu_translator.translate('ja', 'zh-CN', texts)
-	translated_sentences = []
-	batch = len(text_regions)
-	if len(trans_ret) < batch :
-		translated_sentences.extend(trans_ret)
-		translated_sentences.extend([''] * (batch - len(trans_ret)))
-	elif len(trans_ret) > batch :
-		translated_sentences.extend(trans_ret[:batch])
+	if trans_ret :
+		translated_sentences = []
+		batch = len(text_regions)
+		if len(trans_ret) < batch :
+			translated_sentences.extend(trans_ret)
+			translated_sentences.extend([''] * (batch - len(trans_ret)))
+		elif len(trans_ret) > batch :
+			translated_sentences.extend(trans_ret[:batch])
+		else :
+			translated_sentences.extend(trans_ret)
 	else :
-		translated_sentences.extend(trans_ret)
+		translated_sentences = texts.split('\n')
 	print(' -- Rendering translated text')
 	# render translated texts
 	img_canvas = np.copy(img_inpainted)
@@ -622,8 +628,8 @@ def main() :
 			text_render.put_text_vertical(img_canvas, trans_text, len(region.textline_indices), region.x, region.y, region.w, region.h, fg, None)
 	print(' -- Saving results')
-	cv2.imwrite('result/rs.png', imgproc.cvt2HeatmapImg(rscore))
-	cv2.imwrite('result/as.png', imgproc.cvt2HeatmapImg(ascore))
+	result_db = db[0, 0, :, :].numpy()
+	cv2.imwrite('result/db.png', imgproc.cvt2HeatmapImg(result_db))
 	cv2.imwrite('result/textline.png', overlay)
 	cv2.imwrite('result/bbox.png', img_bbox)
 	cv2.imwrite('result/bbox_unfiltered.png', img_bbox_all)
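
Below is a minimal, untested sketch of how the pieces introduced in this patch fit together at inference time: loading `detect.ckpt` into `DBNet_resnet101.TextDetection` the way `load_detect_model()` does, normalizing the input as in `run_detect()`, and converting the DBNet probability map into boxes with `dbnet_utils.SegDetectorRepresenter`. The `detect_text` helper, the square-padding preprocessing, and its default arguments are illustrative assumptions, not part of the patch; `translate_demo.py` remains the actual entry point.

```python
import cv2
import numpy as np
import torch
import einops

from DBNet_resnet101 import TextDetection
import dbnet_utils

def detect_text(image_path, ckpt_path='detect.ckpt', size=2048, use_cuda=False):
    # Load the DBNet-based detector the same way load_detect_model() does.
    model = TextDetection()
    sd = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(sd['model'] if 'model' in sd else sd)
    model.eval()
    if use_cuda:
        model = model.cuda()

    # Pad the image onto a square canvas at the detection resolution
    # (a stand-in for imgproc.resize_aspect_ratio used by translate_demo.py).
    img = cv2.imread(image_path)
    h, w = img.shape[:2]
    scale = size / max(h, w)
    resized = cv2.resize(img, (int(w * scale), int(h * scale)))
    canvas = np.zeros((size, size, 3), dtype=np.uint8)
    canvas[:resized.shape[0], :resized.shape[1]] = resized

    # Same normalization as run_detect(): scale pixels to [-1, 1].
    x = torch.from_numpy(canvas.astype(np.float32) / 127.5 - 1.0)
    x = einops.rearrange(x, 'h w c -> 1 c h w')
    if use_cuda:
        x = x.cuda()
    with torch.no_grad():
        db, mask = model(x)  # mask is the text mask later used for inpainting
        db = db.sigmoid().cpu()

    # Turn the shrink-map probabilities into quadrilateral text boxes.
    det = dbnet_utils.SegDetectorRepresenter()
    boxes, scores = det({'shape': [(canvas.shape[0], canvas.shape[1])]}, db)
    return boxes[0], scores[0]

if __name__ == '__main__':
    boxes, scores = detect_text('original1.jpg')
    print(len(boxes), 'text boxes found')
```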