add youdao translate & change linebreak to LF only
zyddnys committed May 12, 2021
1 parent 1e4954e commit cfb6033
Showing 7 changed files with 308 additions and 234 deletions.
306 changes: 153 additions & 153 deletions CRAFT_resnet34.py
@@ -1,153 +1,153 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torchvision.models import resnet34

import einops
import math

class ImageMultiheadSelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map, with a residual connection."""
    def __init__(self, planes):
        super(ImageMultiheadSelfAttention, self).__init__()
        self.attn = nn.MultiheadAttention(planes, 4)

    def forward(self, x):
        res = x
        n, c, h, w = x.shape
        # Flatten the spatial grid into a sequence of h*w tokens for attention.
        x = einops.rearrange(x, 'n c h w -> (h w) n c')
        x = self.attn(x, x, x)[0]
        x = einops.rearrange(x, '(h w) n c -> n c h w', n = n, c = c, h = h, w = w)
        return res + x

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch, stride = 1, planes = 256):
        super(double_conv, self).__init__()
        self.planes = planes
        # Optional 2x average-pool downsampling applied before the convolutions.
        self.down = None
        if stride > 1:
            self.down = nn.AvgPool2d(2, stride = 2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size = 3, padding = 1, stride = 1, bias = False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace = True),
            nn.Conv2d(mid_ch, out_ch, kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace = True),
        )

    def forward(self, x):
        if self.down is not None:
            x = self.down(x)
        x = self.conv(x)
        return x

class CRAFT_net(nn.Module):
    """CRAFT-style text detector: a ResNet34 encoder with a U-Net-style decoder,
    producing region/affinity score maps and a text mask at half the input resolution.
    Comments below note channels@spatial-size for a 1536x1536 input."""
    def __init__(self):
        super(CRAFT_net, self).__init__()
        self.backbone = resnet34()

        # Region score head.
        self.conv_rs = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 1, kernel_size = 1),
            nn.Sigmoid()
        )

        # Affinity score head.
        self.conv_as = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 32, kernel_size = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 1, kernel_size = 1),
            nn.Sigmoid()
        )

        # Text mask head.
        self.conv_mask = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(64, 64, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(64, 32, kernel_size = 3, padding = 1), nn.ReLU(inplace = True),
            nn.Conv2d(32, 1, kernel_size = 1),
            nn.Sigmoid()
        )

        # Extra encoder stages below the ResNet backbone.
        self.down_conv1 = double_conv(0, 512, 512, 2)
        self.down_conv2 = double_conv(0, 512, 512, 2)
        self.down_conv3 = double_conv(0, 512, 512, 2)

        # Decoder stages; each consumes the previous output concatenated
        # with the matching encoder feature map.
        self.upconv1 = double_conv(0, 512, 256)
        self.upconv2 = double_conv(256, 512, 256)
        self.upconv3 = double_conv(256, 512, 256)
        self.upconv4 = double_conv(256, 512, 256, planes = 128)
        self.upconv5 = double_conv(256, 256, 128, planes = 64)
        self.upconv6 = double_conv(128, 128, 64, planes = 32)
        self.upconv7 = double_conv(64, 64, 64, planes = 16)

    def forward_train(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x) # 64@384

        h4 = self.backbone.layer1(x) # 64@384
        h8 = self.backbone.layer2(h4) # 128@192
        h16 = self.backbone.layer3(h8) # 256@96
        h32 = self.backbone.layer4(h16) # 512@48
        h64 = self.down_conv1(h32) # 512@24
        h128 = self.down_conv2(h64) # 512@12
        h256 = self.down_conv3(h128) # 512@6

        up256 = F.interpolate(self.upconv1(h256), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@12
        up128 = F.interpolate(self.upconv2(torch.cat([up256, h128], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@24
        up64 = F.interpolate(self.upconv3(torch.cat([up128, h64], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@48
        up32 = F.interpolate(self.upconv4(torch.cat([up64, h32], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 256@96
        up16 = F.interpolate(self.upconv5(torch.cat([up32, h16], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 128@192
        up8 = F.interpolate(self.upconv6(torch.cat([up16, h8], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@384
        up4 = F.interpolate(self.upconv7(torch.cat([up8, h4], dim = 1)), scale_factor = (2, 2), mode = 'bilinear', align_corners = False) # 64@768

        ascore = self.conv_as(up4)
        rscore = self.conv_rs(up4)

        return torch.cat([rscore, ascore], dim = 1), self.conv_mask(up4)

    def forward(self, x):
        # The inference path is currently identical to the training path.
        return self.forward_train(x)

if __name__ == '__main__':
    net = CRAFT_net().cuda()
    img = torch.randn(2, 3, 1536, 1536).cuda()
    print(net.forward_train(img)[0].shape)
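
For a quick sanity check without a GPU, a minimal CPU sketch of a forward pass (it mirrors the `__main__` block above without `.cuda()`; the expected shapes follow from the channels@size annotations in the code):

    # CPU smoke test -- a sketch; assumes only the code in this file.
    net = CRAFT_net().eval()
    with torch.no_grad():
        scores, mask = net(torch.randn(1, 3, 1536, 1536))
    print(scores.shape)  # expected: torch.Size([1, 2, 768, 768]) -- region + affinity
    print(mask.shape)    # expected: torch.Size([1, 1, 768, 768])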

4 changes: 3 additions & 1 deletion README.md
@@ -4,6 +4,8 @@ Note this may not work sometimes due to stupid google gcp kept restarting my ins
# English README
[README_EN.md](README_EN.md)
# Changelogs
### 2021-05-11
1. Added Youdao translate and made it the default translator
### 2021-05-06
1. Detection model updated to DBNet with a ResNet101 backbone
2. OCR model updated to a deeper one
@@ -23,7 +25,7 @@ Note this may not work sometimes due to stupid google gcp kept restarting my ins
# Usage
1. Clone this repo
2. [Download](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2.1) ocr.ckpt, detect.ckpt and inpainting.ckpt, and put them in the root directory of this repo
3. Apply for a Baidu translate API account and put your appid and secret key in key.py
3. Apply for a Youdao translate API account and put your APP_KEY and APP_SECRET in key.py (a minimal sketch of key.py follows this list)
4. Run `python translate_demo.py --image <path_to_image_file> [--use-inpainting] [--use-cuda]`; results are saved in the result folder. Add `--use-inpainting` to enable inpainting and `--use-cuda` to use the GPU.
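
A minimal sketch of what `key.py` might look like, assuming the translator reads two module-level constants with the names from step 3 (the values below are placeholders, not real credentials):

```python
# key.py -- placeholder credentials; replace with the APP_KEY / APP_SECRET
# issued by your Youdao developer console.
APP_KEY = 'your-youdao-app-key'
APP_SECRET = 'your-youdao-app-secret'
```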
# This is only an early version; we need your help to improve it
This project is currently only a simple demo; many imperfections remain, and we need your help to make it better!
4 changes: 3 additions & 1 deletion README_EN.md
@@ -2,6 +2,8 @@
https://touhou.ai/imgtrans/
Note this may not work sometimes because stupid Google GCP keeps restarting my instance. In that case you can wait for me to restart the service, which may take up to 24 hrs.
# Changelogs
### 2021-05-11
1. Added Youdao translate and set it as the default translator
### 2021-05-06
1. Text detection model is now based on DBNet with a ResNet101 backbone
2. OCR model is now deeper
@@ -21,7 +23,7 @@ Successor to https://github.com/PatchyVideo/MMDOCR-HighPerformance
# How to use
1. Clone this repo
2. [Download](https://github.com/zyddnys/manga-image-translator/releases/tag/alpha-v2.2.1) ocr.ckpt, detect.ckpt and inpainting.ckpt, and put them in the root directory of this repo
3. Apply for the Baidu translate API and put your appid and key in `key.py`
3. Apply for the Youdao translate API and put your APP_KEY and APP_SECRET in `key.py` (see the request-signing sketch after this list)
4. Run `python translate_demo.py --image <path_to_image_file> [--use-inpainting] [--use-cuda]`; results can be found in `result/`. Add `--use-inpainting` to enable inpainting and `--use-cuda` to use CUDA.
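
Since this commit makes Youdao the default translator, here is a sketch of how a signed Youdao translate request is commonly built. The endpoint, parameter names, and the v3 sign rule are recalled from Youdao's public API docs and should be treated as assumptions to verify against the current documentation:

```python
# Sketch of a signed Youdao translate request (endpoint, parameter names,
# and the sign rule are assumptions -- check Youdao's current API docs).
import hashlib
import time
import uuid

import requests

from key import APP_KEY, APP_SECRET

def youdao_translate(text, src = 'ja', dst = 'zh-CHS'):
    def truncate(q):
        # The signing "input": the whole text if short, otherwise
        # first 10 chars + length + last 10 chars.
        return q if len(q) <= 20 else q[:10] + str(len(q)) + q[-10:]
    salt = str(uuid.uuid1())
    curtime = str(int(time.time()))
    raw = APP_KEY + truncate(text) + salt + curtime + APP_SECRET
    sign = hashlib.sha256(raw.encode('utf-8')).hexdigest()
    data = {
        'q': text, 'from': src, 'to': dst, 'appKey': APP_KEY,
        'salt': salt, 'sign': sign, 'signType': 'v3', 'curtime': curtime,
    }
    resp = requests.post('https://openapi.youdao.com/api', data = data)
    return resp.json().get('translation', [''])[0]
```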
# This is a hobby project; you are welcome to contribute
Currently this is only a simple demo; many imperfections exist, and we need your support to make this project better!