Merge pull request #35 from dmMaze/new_detection_models
new text detect models
zyddnys authored Jan 24, 2022
2 parents 8cbbebe + b3c6d2e commit 44e63cd
Showing 16 changed files with 3,040 additions and 74 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
result
*.ckpt
*.pt
.vscode
__pycache__
ocrs
88 changes: 56 additions & 32 deletions ocr/__init__.py
@@ -1,7 +1,7 @@

from collections import Counter
import itertools
from typing import List, Tuple
from typing import List, Tuple, Union
from utils import Quadrilateral, quadrilateral_can_merge_region
import torch
import cv2
@@ -12,6 +12,8 @@
from .model_32px import OCR as OCR_32px
from .model_48px import OCR as OCR_48px

from textblockdetector.textblock import TextBlock

MODEL_32PX = None

def load_model(dictionary, cuda: bool, model_name: str = '32px') :
@@ -36,11 +38,16 @@ def chunks(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i + n]

def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Quadrilateral, str]], max_chunk_size = 16, verbose: bool = False) :
def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Union[Quadrilateral, TextBlock], str]], max_chunk_size = 16, verbose: bool = False) :
text_height = 32
regions = [q.get_transformed_region(img, d, text_height) for q, d in quadrilaterals]
out_regions = []
perm = sorted(range(len(regions)), key = lambda x: regions[x].shape[1])

perm = range(len(regions))
if len(quadrilaterals) > 0:
if isinstance(quadrilaterals[0][0], Quadrilateral):
perm = sorted(range(len(regions)), key = lambda x: regions[x].shape[1])

ix = 0
for indices in chunks(perm, max_chunk_size) :
N = len(indices)
@@ -83,39 +90,56 @@ def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Quadril
txt = ''.join(seq)
print(prob, txt, f'fg: ({fr}, {fg}, {fb})', f'bg: ({br}, {bg}, {bb})')
cur_region = quadrilaterals[indices[i]][0]
cur_region.text = txt
cur_region.prob = prob
cur_region.fg_r = fr
cur_region.fg_g = fg
cur_region.fg_b = fb
cur_region.bg_r = br
cur_region.bg_g = bg
cur_region.bg_b = bb
if isinstance(cur_region, Quadrilateral):
cur_region.text = txt
cur_region.prob = prob
cur_region.fg_r = fr
cur_region.fg_g = fg
cur_region.fg_b = fb
cur_region.bg_r = br
cur_region.bg_g = bg
cur_region.bg_b = bb
else:
cur_region.text.append(txt)
cur_region.fg_r += fr
cur_region.fg_g += fg
cur_region.fg_b += fb
cur_region.bg_r += br
cur_region.bg_g += bg
cur_region.bg_b += bb

out_regions.append(cur_region)
return out_regions

def generate_text_direction(bboxes: List[Quadrilateral]) :
G = nx.Graph()
for i, box in enumerate(bboxes) :
G.add_node(i, box = box)
for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2) :
if quadrilateral_can_merge_region(ubox, vbox) :
G.add_edge(u, v)
for node_set in nx.algorithms.components.connected_components(G) :
nodes = list(node_set)
# majority vote for direction
dirs = [box.direction for box in [bboxes[i] for i in nodes]]
majority_dir = Counter(dirs).most_common(1)[0][0]
# sort
if majority_dir == 'h' :
nodes = sorted(nodes, key = lambda x: bboxes[x].aabb.y + bboxes[x].aabb.h // 2)
elif majority_dir == 'v' :
nodes = sorted(nodes, key = lambda x: -(bboxes[x].aabb.x + bboxes[x].aabb.w))
# yield overall bbox and sorted indices
for node in nodes :
yield bboxes[node], majority_dir
def generate_text_direction(bboxes: List[Union[Quadrilateral, TextBlock]]) :
if len(bboxes) > 0:
if isinstance(bboxes[0], TextBlock):
for blk in bboxes:
majority_dir = 'v' if blk.vertical else 'h'
for line_idx in range(len(blk.lines)):
yield blk, line_idx
else:
G = nx.Graph()
for i, box in enumerate(bboxes) :
G.add_node(i, box = box)
for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2) :
if quadrilateral_can_merge_region(ubox, vbox) :
G.add_edge(u, v)
for node_set in nx.algorithms.components.connected_components(G) :
nodes = list(node_set)
# majority vote for direction
dirs = [box.direction for box in [bboxes[i] for i in nodes]]
majority_dir = Counter(dirs).most_common(1)[0][0]
# sort
if majority_dir == 'h' :
nodes = sorted(nodes, key = lambda x: bboxes[x].aabb.y + bboxes[x].aabb.h // 2)
elif majority_dir == 'v' :
nodes = sorted(nodes, key = lambda x: -(bboxes[x].aabb.x + bboxes[x].aabb.w))
# yield overall bbox and sorted indices
for node in nodes :
yield bboxes[node], majority_dir

async def dispatch(img: np.ndarray, textlines: List[Quadrilateral], cuda: bool, args: dict, model_name: str = '32px', batch_size: int = 16, verbose: bool = False) -> List[Quadrilateral] :
async def dispatch(img: np.ndarray, textlines: List[Union[Quadrilateral, TextBlock]], cuda: bool, args: dict, model_name: str = '32px', batch_size: int = 16, verbose: bool = False) -> List[Quadrilateral] :
print(' -- Running OCR')
if model_name == '32px' :
return run_ocr_32px(img, cuda, list(generate_text_direction(textlines)), batch_size)
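
The reworked OCR path is polymorphic over the region type: with the legacy detectors, generate_text_direction yields (Quadrilateral, majority_direction) pairs, while with the new text block detector it yields one (TextBlock, line_index) pair per detected line, and run_ocr_32px appends each recognized line to the block and accumulates its foreground/background colors rather than overwriting them. A minimal caller sketch of the two shapes (illustration only, not part of the diff; describe_regions is a hypothetical helper):

from utils import Quadrilateral
from ocr import generate_text_direction

def describe_regions(bboxes):
    # bboxes: List[Quadrilateral] (legacy path) or List[TextBlock] (new detector)
    for region, extra in generate_text_direction(bboxes):
        if isinstance(region, Quadrilateral):
            # legacy path: extra is the majority direction, 'h' or 'v'
            print('quadrilateral, direction =', extra)
        else:
            # text block path: extra indexes into region.lines
            print('text block, line', extra, 'of', len(region.lines))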
109 changes: 109 additions & 0 deletions text_rendering/__init__.py
@@ -6,6 +6,7 @@
from utils import findNextPowerOf2

from . import text_render
from textblockdetector.textblock import TextBlock

async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], force_horizontal: bool) -> np.ndarray :
for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) :
@@ -110,3 +111,111 @@ async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translate
mask_region = rgba_region[:, :, 3: 4].astype(np.float32) / 255.0
img_canvas = np.clip((img_canvas.astype(np.float32) * (1 - mask_region) + canvas_region.astype(np.float32) * mask_region), 0, 255).astype(np.uint8)
return img_canvas


async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], force_horizontal: bool) -> np.ndarray :
for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) :
print(f'text: {region.get_text()} \n trans: {trans_text}')
if not trans_text :
continue
if force_horizontal :
majority_dir = 'h'
else:
majority_dir = 'v' if region.vertical else 'h'

fg, bg = region.get_font_colors()
font_size = region.font_size
font_size = round(font_size)

region_x, region_y, region_w, region_h = region.xyxy
region_w -= region_x
region_h -= region_y

textlines = []
for ii, text in enumerate(region.text):
textlines.append(Quadrilateral(np.array(region.lines[ii]), text, 1, region.fg_r, region.fg_g, region.fg_b, region.bg_r, region.bg_g, region.bg_b))
# region_aabb = region.aabb
# print(region_aabb.x, region_aabb.y, region_aabb.w, region_aabb.h)

# round font_size to fixed powers of 2, so later LRU cache can work
font_size_enlarged = findNextPowerOf2(font_size) * text_mag_ratio
enlarge_ratio = font_size_enlarged / font_size
font_size = font_size_enlarged
while True :
enlarged_w = round(enlarge_ratio * region_w)
enlarged_h = round(enlarge_ratio * region_h)
rows = enlarged_h // (font_size * 1.3)
cols = enlarged_w // (font_size * 1.3)
if rows * cols < len(trans_text) :
enlarge_ratio *= 1.1
continue
break
print('font_size:', font_size)

tmp_canvas = np.ones((enlarged_h * 2, enlarged_w * 2, 3), dtype = np.uint8) * 127
tmp_mask = np.zeros((enlarged_h * 2, enlarged_w * 2), dtype = np.uint16)

if majority_dir == 'h' :
text_render.put_text_horizontal(
font_size,
enlarge_ratio * 1.0,
tmp_canvas,
tmp_mask,
trans_text,
len(region.lines),
textlines,
enlarged_w // 2,
enlarged_h // 2,
enlarged_w,
enlarged_h,
fg,
bg
)
else :
text_render.put_text_vertical(
font_size,
enlarge_ratio * 1.0,
tmp_canvas,
tmp_mask,
trans_text,
len(region.lines),
textlines,
enlarged_w // 2,
enlarged_h // 2,
enlarged_w,
enlarged_h,
fg,
bg
)

tmp_mask = np.clip(tmp_mask, 0, 255).astype(np.uint8)
x, y, w, h = cv2.boundingRect(tmp_mask)
r_prime = w / h

r = region.aspect_ratio()
if majority_dir != 'v':
r = 1 / r

w_ext = 0
h_ext = 0
if r_prime > r :
h_ext = w / (2 * r) - h / 2
else :
w_ext = (h * r - w) / 2
region_ext = round(min(w, h) * 0.05)
h_ext += region_ext
w_ext += region_ext
src_pts = np.array([[x - w_ext, y - h_ext], [x + w + w_ext, y - h_ext], [x + w + w_ext, y + h + h_ext], [x - w_ext, y + h + h_ext]]).astype(np.float32)
src_pts[:, 0] = np.clip(np.round(src_pts[:, 0]), 0, enlarged_w * 2)
src_pts[:, 1] = np.clip(np.round(src_pts[:, 1]), 0, enlarged_h * 2)

dst_pts = region.min_rect()
if majority_dir == 'v':
dst_pts = dst_pts[:, [3, 0, 1, 2]]
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
tmp_rgba = np.concatenate([tmp_canvas, tmp_mask[:, :, None]], axis = -1).astype(np.float32)
rgba_region = np.clip(cv2.warpPerspective(tmp_rgba, M, (img_canvas.shape[1], img_canvas.shape[0]), flags = cv2.INTER_LINEAR, borderMode = cv2.BORDER_CONSTANT, borderValue = 0), 0, 255)
canvas_region = rgba_region[:, :, 0: 3]
mask_region = rgba_region[:, :, 3: 4].astype(np.float32) / 255.0
img_canvas = np.clip((img_canvas.astype(np.float32) * (1 - mask_region) + canvas_region.astype(np.float32) * mask_region), 0, 255).astype(np.uint8)
return img_canvas
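
For reference, the sizing loop in dispatch_ctd_render searches for the smallest enlargement whose character grid can hold the translation: with a cell pitch of font_size * 1.3, it keeps growing until rows * cols >= len(trans_text). The same arithmetic as a standalone sketch (simplified: the real loop starts from the power-of-two enlargement ratio computed above, not from 1.0):

def fit_enlarge_ratio(region_w: int, region_h: int, font_size: int, n_chars: int) -> float:
    # Grow the target canvas until a rows x cols grid of
    # (font_size * 1.3)-sized cells can hold every character.
    enlarge_ratio = 1.0
    while True:
        rows = round(enlarge_ratio * region_h) // (font_size * 1.3)
        cols = round(enlarge_ratio * region_w) // (font_size * 1.3)
        if rows * cols >= n_chars:
            return enlarge_ratio
        enlarge_ratio *= 1.1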
126 changes: 126 additions & 0 deletions textblockdetector/__init__.py
@@ -0,0 +1,126 @@
import json
from .basemodel import TextDetBase
import os.path as osp
from tqdm import tqdm
import numpy as np
import cv2
import torch
from pathlib import Path
import onnxruntime
from .utils.yolov5_utils import non_max_suppression
from .utils.db_utils import SegDetectorRepresenter
from .utils.io_utils import imread, imwrite, find_all_imgs, NumpyEncoder
from .utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings
from .textblock import TextBlock, group_output
from .textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION

def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
if bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64)
img_in = img_in.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255
if to_tensor:
img_in = torch.from_numpy(img_in).to(device)
if half:
img_in = img_in.half()
return img_in, ratio, int(dw), int(dh)

def postprocess_mask(img: torch.Tensor, thresh=None):
# img = img.permute(1, 2, 0)
if thresh is not None:
img = img > thresh
img = img * 255
if img.device != 'cpu':
img = img.detach_().cpu()
img = img.numpy().astype(np.uint8)
return img

def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None):
det = non_max_suppression(det, conf_thresh, nms_thresh)[0]
# bbox = det[..., 0:4]
if det.device != 'cpu':
det = det.detach_().cpu().numpy()
det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0]
det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1]
if sort_func is not None:
det = sort_func(det)

blines = det[..., 0:4].astype(np.int32)
confs = np.round(det[..., 4], 3)
cls = det[..., 5].astype(np.int32)
return blines, cls, confs

class TextDetector:
lang_list = ['eng', 'ja', 'unknown']
langcls2idx = {'eng': 0, 'ja': 1, 'unknown': 2}

def __init__(self, model_path, input_size=1152, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4, mask_thresh=0.3, act='leaky', backend='torch') :
super(TextDetector, self).__init__()
cuda = device == 'cuda'
self.backend = backend
if self.backend == 'torch':
self.net = TextDetBase(model_path, device=device, act=act)
else:
# TODO: OPENCV ONNX INFERENCE
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
self.session = onnxruntime.InferenceSession(model_path, providers=providers)
if isinstance(input_size, int):
input_size = (input_size, input_size)
self.input_size = input_size
self.device = device
self.half = half
self.conf_thresh = conf_thresh
self.nms_thresh = nms_thresh
self.seg_rep = SegDetectorRepresenter(thresh=0.3)

def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False, bgr2rgb=True):
img_in, ratio, dw, dh = preprocess_img(img, input_size=self.input_size, device=self.device, half=self.half, bgr2rgb=bgr2rgb)

im_h, im_w = img.shape[:2]
with torch.no_grad():
blks, mask, lines_map = self.net(img_in)

resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh))
blks = postprocess_yolo(blks[0], self.conf_thresh, self.nms_thresh, resize_ratio)
mask = postprocess_mask(mask.squeeze_())
lines, scores = self.seg_rep(self.input_size, lines_map)
box_thresh = 0.6
idx = np.where(scores[0] > box_thresh)
lines, scores = lines[0][idx], scores[0][idx]

# map output to input img
mask = mask[: mask.shape[0]-dh, : mask.shape[1]-dw]
mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
if lines.size == 0 :
lines = []
else :
lines = lines.astype(np.float64)
lines[..., 0] *= resize_ratio[0]
lines[..., 1] *= resize_ratio[1]
lines = lines.astype(np.int32)
blk_list = group_output(blks, lines, im_w, im_h, mask)
mask_refined = refine_mask(img, mask, blk_list, refine_mode=refine_mode)
if keep_undetected_mask:
mask_refined = refine_undetected_mask(img, mask, mask_refined, blk_list, refine_mode=refine_mode)

return mask, mask_refined, blk_list

def cuda(self):
self.net.to('cuda')

DEFAULT_MODEL = None
def load_model(cuda: bool):
global DEFAULT_MODEL
device = 'cuda' if cuda else 'cpu'
model = TextDetector(model_path='comictextdetector.pt', device=device, act='leaky')
if cuda :
model.cuda()
DEFAULT_MODEL = model

async def dispatch(img: np.ndarray, cuda: bool):
global DEFAULT_MODEL
if DEFAULT_MODEL is None :
load_model(cuda)
return DEFAULT_MODEL(img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False, bgr2rgb=False)
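
Putting the new module together, detection is a single awaited call. A minimal usage sketch (assumptions: comictextdetector.pt sits in the working directory, as load_model expects; the input array is already RGB, since dispatch invokes the detector with bgr2rgb=False; input.jpg is a placeholder path):

import asyncio
import cv2
import textblockdetector

async def main():
    # dispatch() runs the detector with bgr2rgb=False, so convert up front.
    img = cv2.cvtColor(cv2.imread('input.jpg'), cv2.COLOR_BGR2RGB)
    mask, mask_refined, blk_list = await textblockdetector.dispatch(img, cuda=False)
    print(len(blk_list), 'text blocks detected')
    for blk in blk_list:
        print(blk.lines, blk.vertical)  # fields consumed by the OCR/render paths above

asyncio.run(main())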