From 0332c7313bd3a6dbc435c05373053ca3b4b0f778 Mon Sep 17 00:00:00 2001
From: dmMaze
Date: Sat, 22 Jan 2022 15:56:25 +0800
Subject: [PATCH 1/3] new text detect models

---
 .gitignore                               |   2 +
 ocr/__init__.py                          |  67 ++-
 text_rendering/__init__.py               | 112 ++++
 textblockdetector/__init__.py            | 126 ++++
 textblockdetector/basemodel.py           | 241 ++++
 textblockdetector/textblock.py           | 375 ++++++++++++
 textblockdetector/textmask.py            | 170 ++++++
 textblockdetector/utils/db_utils.py      | 695 +++++++++++++++++++++++
 textblockdetector/utils/imgproc_utils.py | 171 ++++++
 textblockdetector/utils/io_utils.py      |  54 ++
 textblockdetector/utils/weight_init.py   | 103 ++++
 textblockdetector/utils/yolov5_utils.py  | 240 ++++
 textblockdetector/yolov5/common.py       | 290 ++++++++++
 textblockdetector/yolov5/yolo.py         | 311 ++++++++++
 translate_demo.py                        | 107 ++--
 translators/__init__.py                  |   7 +-
 16 files changed, 3007 insertions(+), 64 deletions(-)
 create mode 100644 textblockdetector/__init__.py
 create mode 100644 textblockdetector/basemodel.py
 create mode 100644 textblockdetector/textblock.py
 create mode 100644 textblockdetector/textmask.py
 create mode 100644 textblockdetector/utils/db_utils.py
 create mode 100644 textblockdetector/utils/imgproc_utils.py
 create mode 100644 textblockdetector/utils/io_utils.py
 create mode 100644 textblockdetector/utils/weight_init.py
 create mode 100644 textblockdetector/utils/yolov5_utils.py
 create mode 100644 textblockdetector/yolov5/common.py
 create mode 100644 textblockdetector/yolov5/yolo.py

diff --git a/.gitignore b/.gitignore
index 69ae39640..a07961125 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
 result
 *.ckpt
+*.pt
 .vscode
+translators
 __pycache__
 ocrs
\ No newline at end of file
diff --git a/ocr/__init__.py b/ocr/__init__.py
index be5db90c2..45be753e6 100644
--- a/ocr/__init__.py
+++ b/ocr/__init__.py
@@ -1,7 +1,7 @@
 from collections import Counter
 import itertools

-from typing import List, Tuple
+from typing import List, Tuple, Union
 from utils import Quadrilateral, quadrilateral_can_merge_region
 import torch
 import cv2
@@ -12,6 +12,8 @@
 from .model_32px import OCR as OCR_32px
 from .model_48px import OCR as OCR_48px

+from textblockdetector.textblock import TextBlock
+
 MODEL_32PX = None

 def load_model(dictionary, cuda: bool, model_name: str = '32px') :
@@ -36,11 +38,16 @@ def chunks(lst, n):
     for i in range(0, len(lst), n):
         yield lst[i:i + n]

-def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Quadrilateral, str]], max_chunk_size = 16, verbose: bool = False) :
+def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Union[Quadrilateral, TextBlock], str]], max_chunk_size = 16, verbose: bool = False) :
     text_height = 32
     regions = [q.get_transformed_region(img, d, text_height) for q, d in quadrilaterals]
     out_regions = []
-    perm = sorted(range(len(regions)), key = lambda x: regions[x].shape[1])
+
+    perm = range(len(regions))
+    if len(quadrilaterals) > 0:
+        if isinstance(quadrilaterals[0][0], Quadrilateral):
+            perm = sorted(range(len(regions)), key = lambda x: regions[x].shape[1])
+
     ix = 0
     for indices in chunks(perm, max_chunk_size) :
         N = len(indices)
@@ -83,7 +90,10 @@ def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Quadril
             txt = ''.join(seq)
             print(prob, txt, f'fg: ({fr}, {fg}, {fb})', f'bg: ({br}, {bg}, {bb})')
             cur_region = quadrilaterals[indices[i]][0]
-            cur_region.text = txt
+            if isinstance(cur_region, Quadrilateral):
+                cur_region.text = txt
+            else:
+                cur_region.text.append(txt)
             cur_region.prob = prob
             cur_region.fg_r = fr
cur_region.fg_g = fg @@ -94,28 +104,35 @@ def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Quadril out_regions.append(cur_region) return out_regions -def generate_text_direction(bboxes: List[Quadrilateral]) : - G = nx.Graph() - for i, box in enumerate(bboxes) : - G.add_node(i, box = box) - for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2) : - if quadrilateral_can_merge_region(ubox, vbox) : - G.add_edge(u, v) - for node_set in nx.algorithms.components.connected_components(G) : - nodes = list(node_set) - # majority vote for direction - dirs = [box.direction for box in [bboxes[i] for i in nodes]] - majority_dir = Counter(dirs).most_common(1)[0][0] - # sort - if majority_dir == 'h' : - nodes = sorted(nodes, key = lambda x: bboxes[x].aabb.y + bboxes[x].aabb.h // 2) - elif majority_dir == 'v' : - nodes = sorted(nodes, key = lambda x: -(bboxes[x].aabb.x + bboxes[x].aabb.w)) - # yield overall bbox and sorted indices - for node in nodes : - yield bboxes[node], majority_dir +def generate_text_direction(bboxes: List[Union[Quadrilateral, TextBlock]]) : + if len(bboxes) > 0: + if isinstance(bboxes[0], TextBlock): + for blk in bboxes: + majority_dir = 'v' if blk.vertical else 'h' + for line_idx in range(len(blk.lines)): + yield blk, line_idx + else: + G = nx.Graph() + for i, box in enumerate(bboxes) : + G.add_node(i, box = box) + for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2) : + if quadrilateral_can_merge_region(ubox, vbox) : + G.add_edge(u, v) + for node_set in nx.algorithms.components.connected_components(G) : + nodes = list(node_set) + # majority vote for direction + dirs = [box.direction for box in [bboxes[i] for i in nodes]] + majority_dir = Counter(dirs).most_common(1)[0][0] + # sort + if majority_dir == 'h' : + nodes = sorted(nodes, key = lambda x: bboxes[x].aabb.y + bboxes[x].aabb.h // 2) + elif majority_dir == 'v' : + nodes = sorted(nodes, key = lambda x: -(bboxes[x].aabb.x + bboxes[x].aabb.w)) + # yield overall bbox and sorted indices + for node in nodes : + yield bboxes[node], majority_dir -async def dispatch(img: np.ndarray, textlines: List[Quadrilateral], cuda: bool, args: dict, model_name: str = '32px', batch_size: int = 16, verbose: bool = False) -> List[Quadrilateral] : +async def dispatch(img: np.ndarray, textlines: List[Union[Quadrilateral, TextBlock]], cuda: bool, args: dict, model_name: str = '32px', batch_size: int = 16, verbose: bool = False) -> List[Quadrilateral] : print(' -- Running OCR') if model_name == '32px' : return run_ocr_32px(img, cuda, list(generate_text_direction(textlines)), batch_size) diff --git a/text_rendering/__init__.py b/text_rendering/__init__.py index 1370f36b7..7f175cc3e 100644 --- a/text_rendering/__init__.py +++ b/text_rendering/__init__.py @@ -6,6 +6,7 @@ from utils import findNextPowerOf2 from . 
import text_render +from textblockdetector.textblock import TextBlock async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], force_horizontal: bool) -> np.ndarray : for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) : @@ -110,3 +111,114 @@ async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translate mask_region = rgba_region[:, :, 3: 4].astype(np.float32) / 255.0 img_canvas = np.clip((img_canvas.astype(np.float32) * (1 - mask_region) + canvas_region.astype(np.float32) * mask_region), 0, 255).astype(np.uint8) return img_canvas + + +async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], force_horizontal: bool) -> np.ndarray : + for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) : + if not trans_text : + continue + if force_horizontal : + majority_dir = 'h' + else: + majority_dir = 'v' if region.vertical else 'h' + print(region.text) + print(trans_text) + fg = (region.fg_r, region.fg_g, region.fg_b) + bg = (region.bg_r, region.bg_g, region.bg_b) + font_size = region.font_size + font_size = round(font_size) + + region_x, region_y, region_w, region_h = region.xyxy + region_w -= region_x + region_h -= region_y + + textlines = [] + for ii, text in enumerate(region.text): + textlines.append(Quadrilateral(np.array(region.lines[ii]), text, 1, region.fg_r, region.fg_g, region.fg_b, region.bg_r, region.bg_g, region.bg_b)) + # region_aabb = region.aabb + # print(region_aabb.x, region_aabb.y, region_aabb.w, region_aabb.h) + + # round font_size to fixed powers of 2, so later LRU cache can work + font_size_enlarged = findNextPowerOf2(font_size) * text_mag_ratio + enlarge_ratio = font_size_enlarged / font_size + font_size = font_size_enlarged + while True : + enlarged_w = round(enlarge_ratio * region_w) + enlarged_h = round(enlarge_ratio * region_h) + rows = enlarged_h // (font_size * 1.3) + cols = enlarged_w // (font_size * 1.3) + if rows * cols < len(trans_text) : + enlarge_ratio *= 1.1 + continue + break + print('font_size:', font_size) + + tmp_canvas = np.ones((enlarged_h * 2, enlarged_w * 2, 3), dtype = np.uint8) * 127 + tmp_mask = np.zeros((enlarged_h * 2, enlarged_w * 2), dtype = np.uint16) + + if majority_dir == 'h' : + text_render.put_text_horizontal( + font_size, + enlarge_ratio * 1.0, + tmp_canvas, + tmp_mask, + trans_text, + len(region.lines), + textlines, + enlarged_w // 2, + enlarged_h // 2, + enlarged_w, + enlarged_h, + fg, + bg + ) + else : + text_render.put_text_vertical( + font_size, + enlarge_ratio * 1.0, + tmp_canvas, + tmp_mask, + trans_text, + len(region.lines), + textlines, + enlarged_w // 2, + enlarged_h // 2, + enlarged_w, + enlarged_h, + fg, + bg + ) + + tmp_mask = np.clip(tmp_mask, 0, 255).astype(np.uint8) + x, y, w, h = cv2.boundingRect(tmp_mask) + r_prime = w / h + + r = region.aspect_ratio + if majority_dir != 'v': + r = 1 / r + + w_ext = 0 + h_ext = 0 + if r_prime > r : + h_ext = w / (2 * r) - h / 2 + else : + w_ext = (h * r - w) / 2 + region_ext = round(min(w, h) * 0.05) + h_ext += region_ext + w_ext += region_ext + src_pts = np.array([[x - w_ext, y - h_ext], [x + w + w_ext, y - h_ext], [x + w + w_ext, y + h + h_ext], [x - w_ext, y + h + h_ext]]).astype(np.float32) + src_pts[:, 0] = np.clip(np.round(src_pts[:, 0]), 0, enlarged_w * 2) + src_pts[:, 1] = np.clip(np.round(src_pts[:, 1]), 
0, enlarged_h * 2) + # dst_pts = region.mini_rect[:, [3, 0, 1, 2]] + if majority_dir == 'v': + dst_pts = region.mini_rect[:, [3, 0, 1, 2]] + else: + dst_pts = region.mini_rect + # dst_pts = region.mini_rect + M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) + tmp_rgba = np.concatenate([tmp_canvas, tmp_mask[:, :, None]], axis = -1).astype(np.float32) + rgba_region = np.clip(cv2.warpPerspective(tmp_rgba, M, (img_canvas.shape[1], img_canvas.shape[0]), flags = cv2.INTER_LINEAR, borderMode = cv2.BORDER_CONSTANT, borderValue = 0), 0, 255) + canvas_region = rgba_region[:, :, 0: 3] + mask_region = rgba_region[:, :, 3: 4].astype(np.float32) / 255.0 + img_canvas = np.clip((img_canvas.astype(np.float32) * (1 - mask_region) + canvas_region.astype(np.float32) * mask_region), 0, 255).astype(np.uint8) + return img_canvas \ No newline at end of file diff --git a/textblockdetector/__init__.py b/textblockdetector/__init__.py new file mode 100644 index 000000000..3d3d413e0 --- /dev/null +++ b/textblockdetector/__init__.py @@ -0,0 +1,126 @@ +import json +from .basemodel import TextDetBase +import os.path as osp +from tqdm import tqdm +import numpy as np +import cv2 +import torch +from pathlib import Path +import torch +import onnxruntime +from .utils.yolov5_utils import non_max_suppression +from .utils.db_utils import SegDetectorRepresenter +from .utils.io_utils import imread, imwrite, find_all_imgs, NumpyEncoder +from .utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings +from .textblock import TextBlock, group_output +from .textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION + +def preprocess_img(img, input_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True): + if bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_in, ratio, (dw, dh) = letterbox(img, new_shape=input_size, auto=False, stride=64) + img_in = img_in.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + img_in = np.array([np.ascontiguousarray(img_in)]).astype(np.float32) / 255 + if to_tensor: + img_in = torch.from_numpy(img_in).to(device) + if half: + img_in = img_in.half() + return img_in, ratio, int(dw), int(dh) + +def postprocess_mask(img: torch.Tensor, thresh=None): + # img = img.permute(1, 2, 0) + if thresh is not None: + img = img > thresh + img = img * 255 + if img.device != 'cpu': + img = img.detach_().cpu() + img = img.numpy().astype(np.uint8) + return img + +def postprocess_yolo(det, conf_thresh, nms_thresh, resize_ratio, sort_func=None): + det = non_max_suppression(det, conf_thresh, nms_thresh)[0] + # bbox = det[..., 0:4] + if det.device != 'cpu': + det = det.detach_().cpu().numpy() + det[..., [0, 2]] = det[..., [0, 2]] * resize_ratio[0] + det[..., [1, 3]] = det[..., [1, 3]] * resize_ratio[1] + if sort_func is not None: + det = sort_func(det) + + blines = det[..., 0:4].astype(np.int32) + confs = np.round(det[..., 4], 3) + cls = det[..., 5].astype(np.int32) + return blines, cls, confs + +class TextDetector: + lang_list = ['eng', 'ja', 'unknown'] + langcls2idx = {'eng': 0, 'ja': 1, 'unknown': 2} + + def __init__(self, model_path, input_size=1152, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4, mask_thresh=0.3, act='leaky', backend='torch') : + super(TextDetector, self).__init__() + cuda = device == 'cuda' + self.backend = backend + if self.backend == 'torch': + self.net = TextDetBase(model_path, device=device, act=act) + else: + # TODO: OPENCV ONNX INFERENCE + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 
if cuda else ['CPUExecutionProvider'] + self.session = onnxruntime.InferenceSession(model_path, providers=providers) + if isinstance(input_size, int): + input_size = (input_size, input_size) + self.input_size = input_size + self.device = device + self.half = half + self.conf_thresh = conf_thresh + self.nms_thresh = nms_thresh + self.seg_rep = SegDetectorRepresenter(thresh=0.3) + + def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False, bgr2rgb=True): + img_in, ratio, dw, dh = preprocess_img(img, input_size=self.input_size, device=self.device, half=self.half, bgr2rgb=bgr2rgb) + + im_h, im_w = img.shape[:2] + with torch.no_grad(): + blks, mask, lines_map = self.net(img_in) + + resize_ratio = (im_w / (self.input_size[0] - dw), im_h / (self.input_size[1] - dh)) + blks = postprocess_yolo(blks[0], self.conf_thresh, self.nms_thresh, resize_ratio) + mask = postprocess_mask(mask.squeeze_()) + lines, scores = self.seg_rep(self.input_size, lines_map) + box_thresh = 0.6 + idx = np.where(scores[0] > box_thresh) + lines, scores = lines[0][idx], scores[0][idx] + + # map output to input img + mask = mask[: mask.shape[0]-dh, : mask.shape[1]-dw] + mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR) + if lines.size == 0 : + lines = [] + else : + lines = lines.astype(np.float64) + lines[..., 0] *= resize_ratio[0] + lines[..., 1] *= resize_ratio[1] + lines = lines.astype(np.int32) + blk_list = group_output(blks, lines, im_w, im_h, mask) + mask_refined = refine_mask(img, mask, blk_list, refine_mode=refine_mode) + if keep_undetected_mask: + mask_refined = refine_undetected_mask(img, mask, mask_refined, blk_list, refine_mode=refine_mode) + + return mask, mask_refined, blk_list + + def cuda(self): + self.net.to('cuda') + +DEFAULT_MODEL = None +def load_model(cuda: bool): + global DEFAULT_MODEL + device = 'cuda' if cuda else 'cpu' + model = TextDetector(model_path='comictextdetector.pt', device=device, act='leaky') + if cuda : + model.cuda() + DEFAULT_MODEL = model + +async def dispatch(img: np.ndarray, cuda: bool): + global DEFAULT_MODEL + if DEFAULT_MODEL is None : + load_model(cuda) + return DEFAULT_MODEL(img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False, bgr2rgb=False) \ No newline at end of file diff --git a/textblockdetector/basemodel.py b/textblockdetector/basemodel.py new file mode 100644 index 000000000..5a134f076 --- /dev/null +++ b/textblockdetector/basemodel.py @@ -0,0 +1,241 @@ +from .yolov5.yolo import Model +import torch +import cv2 +import numpy as np +from .yolov5.yolo import load_yolov5_ckpt +from .utils.yolov5_utils import fuse_conv_and_bn +import glob +import torch.nn as nn +from .utils.weight_init import init_weights +from .yolov5.common import C3, Conv +# from torchsummary import summary +import copy + +TEXTDET_MASK = 0 +TEXTDET_DET = 1 +TEXTDET_INFERENCE = 2 + +class double_conv_up_c3(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch, act=True): + super(double_conv_up_c3, self).__init__() + self.conv = nn.Sequential( + C3(in_ch+mid_ch, mid_ch, act=act), + nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=4, stride = 2, padding=1, bias=False), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.conv(x) + +class double_conv_c3(nn.Module): + def __init__(self, in_ch, out_ch, stride=1, act=True): + super(double_conv_c3, self).__init__() + if stride > 1 : + self.down = nn.AvgPool2d(2,stride=2) if stride > 1 else None + self.conv = C3(in_ch, out_ch, act=act) + + def forward(self, x): + if 
self.down is not None : + x = self.down(x) + x = self.conv(x) + return x + +class UnetHead(nn.Module): + def __init__(self, act=True) -> None: + + super(UnetHead, self).__init__() + self.down_conv1 = double_conv_c3(512, 512, 2, act=act) + self.upconv0 = double_conv_up_c3(0, 512, 256, act=act) + self.upconv2 = double_conv_up_c3(256, 512, 256, act=act) + self.upconv3 = double_conv_up_c3(0, 512, 256, act=act) + self.upconv4 = double_conv_up_c3(128, 256, 128, act=act) + self.upconv5 = double_conv_up_c3(64, 128, 64, act=act) + self.upconv6 = nn.Sequential( + nn.ConvTranspose2d(64, 1, kernel_size=4, stride = 2, padding=1, bias=False), + nn.Sigmoid() + ) + + def forward(self, f160, f80, f40, f20, f3, forward_mode=TEXTDET_MASK): + # input: 640@3 + d10 = self.down_conv1(f3) # 512@10 + u20 = self.upconv0(d10) # 256@10 + u40 = self.upconv2(torch.cat([f20, u20], dim = 1)) # 256@40 + + if forward_mode == TEXTDET_DET: + return f80, f40, u40 + else: + u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80 + u160 = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160 + u320 = self.upconv5(torch.cat([f160, u160], dim = 1)) # 64@320 + mask = self.upconv6(u320) + if forward_mode == TEXTDET_MASK: + return mask + else: + return mask, [f80, f40, u40] + + def init_weight(self, init_func): + self.apply(init_func) + +class DBHead(nn.Module): + def __init__(self, in_channels, k = 50, shrink_with_sigmoid=True, act=True): + super().__init__() + self.k = k + self.shrink_with_sigmoid = shrink_with_sigmoid + self.upconv3 = double_conv_up_c3(0, 512, 256, act=act) + self.upconv4 = double_conv_up_c3(128, 256, 128, act=act) + self.conv = nn.Sequential( + nn.Conv2d(128, in_channels, 1), + nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True) + ) + self.binarize = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 4, 3, padding=1), + nn.BatchNorm2d(in_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), + nn.BatchNorm2d(in_channels // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, 1, 2, 2) + ) + self.thresh = self._init_thresh(in_channels) + + def forward(self, f80, f40, u40, shrink_with_sigmoid=True, step_eval=False): + shrink_with_sigmoid = self.shrink_with_sigmoid + u80 = self.upconv3(torch.cat([f40, u40], dim = 1)) # 256@80 + x = self.upconv4(torch.cat([f80, u80], dim = 1)) # 128@160 + x = self.conv(x) + threshold_maps = self.thresh(x) + x = self.binarize(x) + shrink_maps = torch.sigmoid(x) + + if self.training: + binary_maps = self.step_function(shrink_maps, threshold_maps) + if shrink_with_sigmoid: + return torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1) + else: + return torch.cat((shrink_maps, threshold_maps, binary_maps, x), dim=1) + else: + if step_eval: + return self.step_function(shrink_maps, threshold_maps) + else: + return torch.cat((shrink_maps, threshold_maps), dim=1) + + def init_weight(self, init_func): + self.apply(init_func) + + def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False): + in_channels = inner_channels + if serial: + in_channels += 1 + self.thresh = nn.Sequential( + nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias), + nn.BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias), + nn.Sigmoid()) + return self.thresh + + def 
_init_upsample(self, in_channels, out_channels, smooth=False, bias=False): + if smooth: + inter_out_channels = out_channels + if out_channels == 1: + inter_out_channels = in_channels + module_list = [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)] + if out_channels == 1: + module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=True)) + return nn.Sequential(module_list) + else: + return nn.ConvTranspose2d(in_channels, out_channels, 2, 2) + + def step_function(self, x, y): + return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) + +class TextDetector(nn.Module): + def __init__(self, weights, map_location='cpu', forward_mode=TEXTDET_MASK, act=True): + super(TextDetector, self).__init__() + + yolov5s_backbone = load_yolov5_ckpt(weights=weights, map_location=map_location) + yolov5s_backbone.eval() + out_indices = [1, 3, 5, 7, 9] + yolov5s_backbone.out_indices = out_indices + yolov5s_backbone.model = yolov5s_backbone.model[:max(out_indices)+1] + self.act = act + self.seg_net = UnetHead(act=act) + self.backbone = yolov5s_backbone + self.dbnet = None + self.forward_mode = forward_mode + + def train_mask(self): + self.forward_mode = TEXTDET_MASK + self.backbone.eval() + self.seg_net.train() + + def initialize_db(self, unet_weights): + self.dbnet = DBHead(64, act=self.act) + self.seg_net.load_state_dict(torch.load(unet_weights, map_location='cpu')['weights']) + self.dbnet.init_weight(init_weights) + self.dbnet.upconv3 = copy.deepcopy(self.seg_net.upconv3) + self.dbnet.upconv4 = copy.deepcopy(self.seg_net.upconv4) + del self.seg_net.upconv3 + del self.seg_net.upconv4 + del self.seg_net.upconv5 + del self.seg_net.upconv6 + # del self.seg_net.conv_mask + + def train_db(self): + self.forward_mode = TEXTDET_DET + self.backbone.eval() + self.seg_net.eval() + self.dbnet.train() + + def forward(self, x): + forward_mode = self.forward_mode + with torch.no_grad(): + outs = self.backbone(x) + if forward_mode == TEXTDET_MASK: + return self.seg_net(*outs, forward_mode=forward_mode) + elif forward_mode == TEXTDET_DET: + with torch.no_grad(): + outs = self.seg_net(*outs, forward_mode=forward_mode) + return self.dbnet(*outs) + +def get_base_det_models(model_path, device='cpu', half=False, act='leaky'): + textdetector_dict = torch.load(model_path, map_location='cpu') + blk_det = load_yolov5_ckpt(textdetector_dict['blk_det'], map_location='cpu') + text_seg = UnetHead(act=act) + text_seg.load_state_dict(textdetector_dict['text_seg']) + text_det = DBHead(64, act=act) + text_det.load_state_dict(textdetector_dict['text_det']) + if half: + return blk_det.eval().to(device).half(), text_seg.eval().to(device).half(), text_det.eval().to(device).half() + return blk_det.eval().to(device), text_seg.eval().to(device), text_det.eval().to(device) + +class TextDetBase(nn.Module): + def __init__(self, model_path, device='cpu', half=False, fuse=False, act='leaky'): + super(TextDetBase, self).__init__() + self.blk_det, self.text_seg, self.text_det = get_base_det_models(model_path, device, half, act=act) + if fuse: + self.fuse() + + def fuse(self): + def _fuse(model): + for m in model.modules(): + if isinstance(m, (Conv)) and hasattr(m, 'bn'): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, 'bn') # remove batchnorm + m.forward = m.forward_fuse # update forward + return model + self.text_seg = _fuse(self.text_seg) + self.text_det = _fuse(self.text_det) + + def forward(self, features): + blks, features = 
self.blk_det(features, detect=True) + mask, features = self.text_seg(*features, forward_mode=TEXTDET_INFERENCE) + lines = self.text_det(*features, step_eval=False) + return blks, mask, lines \ No newline at end of file diff --git a/textblockdetector/textblock.py b/textblockdetector/textblock.py new file mode 100644 index 000000000..b4b04ec94 --- /dev/null +++ b/textblockdetector/textblock.py @@ -0,0 +1,375 @@ +from typing import List +import numpy as np +from shapely.geometry import Polygon +import math +import copy +from .utils.imgproc_utils import union_area, xywh2xyxypoly, rotate_polygons +import cv2 +import functools + +LANG_LIST = ['eng', 'ja', 'unknown'] +LANGCLS2IDX = {'eng': 0, 'ja': 1, 'unknown': 2} + +class TextBlock(object): + def __init__(self, xyxy: List, + lines: List = None, + language: str = 'unknown', + vertical: bool = False, + font_size: float = -1, + distance: List = None, + angle: int = -1, + vec: List = None, + norm: float = -1, + merged: bool = False, + weight: float = -1, + **kwargs) -> None: + self.xyxy = xyxy # boundingbox of textblock + if lines is not None: + self.lines = lines # polygons of textlines + else: + self.lines = [] + self.vertical = vertical # orientation of textlines + self.language = language + self.font_size = font_size + if distance is not None: # distance between textlines and "origin" + self.distance = np.array(distance, np.float64) + else: + self.distance = None + self.angle = angle # rotation angle of textlines + if vec is not None: # primary vector of textblock + self.vec = np.array(vec, np.float64) + else: + vec = None + self.norm = norm # primary norm of textblock + self.merged = merged + self.weight = weight + + self.structure = None + + self.text = list() + self.prob = None + self.fg_r = None + self.fg_g = None + self.fg_b = None + self.bg_r = None + self.bg_g = None + self.bg_b = None + + def adjust_bbox(self, with_bbox=False): + lines = np.array(self.lines) + if with_bbox: + self.xyxy[0] = min(lines[..., 0].min(), self.xyxy[0]) + self.xyxy[1] = min(lines[..., 1].min(), self.xyxy[1]) + self.xyxy[2] = max(lines[..., 0].max(), self.xyxy[2]) + self.xyxy[3] = max(lines[..., 1].max(), self.xyxy[3]) + else: + self.xyxy[0] = lines[..., 0].min() + self.xyxy[1] = lines[..., 1].min() + self.xyxy[2] = lines[..., 0].max() + self.xyxy[3] = lines[..., 1].max() + + def sort_lines(self): + if self.distance is not None: + idx = np.argsort(self.distance) + self.distance = self.distance[idx] + lines = np.array(self.lines, dtype=np.int32) + self.lines = lines[idx].tolist() + self.structure = self.structure[idx] + + def lines_array(self, dtype=np.float64): + return np.array(self.lines, dtype=dtype) + + @functools.cached_property + def aspect_ratio(self) -> float: + mini_rect = self.mini_rect + middle_pnts = (mini_rect[:, [1, 2, 3, 0]] + mini_rect) / 2 + norm_v = np.linalg.norm(middle_pnts[:, 2] - middle_pnts[:, 0]) + norm_h = np.linalg.norm(middle_pnts[:, 1] - middle_pnts[:, 3]) + return norm_v / norm_h + + @functools.cached_property + def mini_rect(self): + center = [self.xyxy[0]/2, self.xyxy[1]/2] + polygons = self.lines_array().reshape(-1, 8) + rotated_polygons = rotate_polygons(center, polygons, self.angle) + min_x = rotated_polygons[:, ::2].min() + min_y = rotated_polygons[:, 1::2].min() + max_x = rotated_polygons[:, ::2].max() + max_y = rotated_polygons[:, 1::2].max() + min_bbox = np.array([[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]]) + min_bbox = rotate_polygons(center, min_bbox, -self.angle) + return min_bbox.reshape(-1, 4, 2) + + def 
__getattribute__(self, name: str): + if name == 'pts': + return self.lines_array() + # else: + return object.__getattribute__(self, name) + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + return self.lines[idx] + + def to_dict(self, extra_info=False): + blk_dict = copy.deepcopy(vars(self)) + if not extra_info: + blk_dict.pop('distance') + blk_dict.pop('weight') + blk_dict.pop('vec') + blk_dict.pop('norm') + return blk_dict + + def get_transformed_region(self, img, idx, textheight) -> np.ndarray : + [l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure[idx]] + v_vec = l2a - l1a + h_vec = l1b - l2b + ratio = np.linalg.norm(v_vec) / np.linalg.norm(h_vec) + src_pts = self.pts[idx].astype(np.float32) + direction = 'v' if self.vertical else 'h' + if direction == 'h' : + h = int(textheight) + w = int(round(textheight / ratio)) + dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32) + M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) + region = cv2.warpPerspective(img, M, (w, h)) + return region + elif direction == 'v' : + w = int(textheight) + h = int(round(textheight * ratio)) + dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32) + M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) + region = cv2.warpPerspective(img, M, (w, h)) + region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE) + # cv2.imshow('region'+str(idx), region) + # cv2.waitKey(0) + return region + + def get_text(self): + return ' '.join(self.text) + +def sort_textblk_list(blk_list: List[TextBlock], im_w: int, im_h: int) -> List[TextBlock]: + if len(blk_list) == 0: + return blk_list + num_ja = 0 + xyxy = [] + for blk in blk_list: + if blk.language == 'ja': + num_ja += 1 + xyxy.append(blk.xyxy) + xyxy = np.array(xyxy) + flip_lr = num_ja > len(blk_list) / 2 + im_oriw = im_w + if im_w > im_h: + im_w /= 2 + num_gridy, num_gridx = 4, 3 + img_area = im_h * im_w + center_x = (xyxy[:, 0] + xyxy[:, 2]) / 2 + if flip_lr: + if im_w != im_oriw: + center_x = im_oriw - center_x + else: + center_x = im_w - center_x + grid_x = (center_x / im_w * num_gridx).astype(np.int32) + center_y = (xyxy[:, 1] + xyxy[:, 3]) / 2 + grid_y = (center_y / im_h * num_gridy).astype(np.int32) + grid_indices = grid_y * num_gridx + grid_x + grid_weights = grid_indices * img_area + 1.2 * (center_x - grid_x * im_w / num_gridx) + (center_y - grid_y * im_h / num_gridy) + if im_w != im_oriw: + grid_weights[np.where(grid_x >= num_gridx)] += img_area * num_gridy * num_gridx + + for blk, weight in zip(blk_list, grid_weights): + blk.weight = weight + blk_list.sort(key=lambda blk: blk.weight) + return blk_list + +def examine_textblk(blk: TextBlock, im_w: int, im_h: int, eval_orientation: bool, sort: bool = False) -> None: + lines = blk.lines_array() + middle_pnts = (lines[:, [1, 2, 3, 0]] + lines) / 2 + vec_v = middle_pnts[:, 2] - middle_pnts[:, 0] # vertical vectors of textlines + vec_h = middle_pnts[:, 1] - middle_pnts[:, 3] # horizontal vectors of textlines + # if sum of vertical vectors is longer, then text orientation is vertical, and vice versa. 
+ center_pnts = (lines[:, 0] + lines[:, 2]) / 2 + v = np.sum(vec_v, axis=0) + h = np.sum(vec_h, axis=0) + norm_v, norm_h = np.linalg.norm(v), np.linalg.norm(h) + vertical = eval_orientation and norm_v > norm_h + # calcuate distance between textlines and origin + if vertical: + primary_vec, primary_norm = v, norm_v + distance_vectors = center_pnts - np.array([[im_w, 0]], dtype=np.float64) # vertical manga text is read from right to left, so origin is (imw, 0) + font_size = int(round(norm_h / len(lines))) + else: + primary_vec, primary_norm = h, norm_h + distance_vectors = center_pnts - np.array([[0, 0]], dtype=np.float64) + font_size = int(round(norm_v / len(lines))) + + rotation_angle = int(math.atan2(primary_vec[1], primary_vec[0]) / math.pi * 180) # rotation angle of textlines + distance = np.linalg.norm(distance_vectors, axis=1) # distance between textlinecenters and origin + rad_matrix = np.arccos(np.einsum('ij, j->i', distance_vectors, primary_vec) / (distance * primary_norm)) + distance = np.abs(np.sin(rad_matrix) * distance) + blk.lines = lines.astype(np.int32).tolist() + blk.distance = distance + blk.angle = rotation_angle + blk.font_size = font_size + blk.vertical = vertical + blk.vec = primary_vec + blk.norm = primary_norm + blk.structure = middle_pnts + if sort: + blk.sort_lines() + +def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.3, distance_tol=2) -> bool: + if blk2.merged: + return False + fntsize_div = blk.font_size / blk2.font_size + num_l1, num_l2 = len(blk), len(blk2) + fntsz_avg = (blk.font_size * num_l1 + blk2.font_size * num_l2) / (num_l1 + num_l2) + vec_prod = blk.vec @ blk2.vec + vec_sum = blk.vec + blk2.vec + cos_vec = vec_prod / blk.norm / blk2.norm + distance = blk2.distance[-1] - blk.distance[-1] + distance_p1 = np.linalg.norm(np.array(blk2.lines[-1][0]) - np.array(blk.lines[-1][0])) + l1, l2 = Polygon(blk.lines[-1]), Polygon(blk2.lines[-1]) + if not l1.intersects(l2): + if fntsize_div > fntsize_tol or 1 / fntsize_div > fntsize_tol: + return False + if abs(cos_vec) < 0.866: # cos30 + return False + if distance > distance_tol * fntsz_avg or distance_p1 > fntsz_avg * 2.5: + return False + # merge + blk.lines.append(blk2.lines[0]) + blk.vec = vec_sum + blk.angle = int(round(np.rad2deg(math.atan2(vec_sum[1], vec_sum[0])))) + blk.norm = np.linalg.norm(vec_sum) + blk.distance = np.append(blk.distance, blk2.distance[-1]) + blk.font_size = fntsz_avg + blk2.merged = True + return True + +def merge_textlines(blk_list: List[TextBlock]) -> List[TextBlock]: + if len(blk_list) < 2: + return blk_list + blk_list.sort(key=lambda blk: blk.distance[0]) + merged_list = list() + for ii, current_blk in enumerate(blk_list): + if current_blk.merged: + continue + for jj, blk in enumerate(blk_list[ii+1:]): + try_merge_textline(current_blk, blk) + merged_list.append(current_blk) + for blk in merged_list: + blk.adjust_bbox(with_bbox=False) + return merged_list + +def split_textblk(blk: TextBlock): + font_size, distance, lines = blk.font_size, blk.distance, blk.lines_array() + distance_tol = font_size * 2 + current_blk = copy.deepcopy(blk) + current_blk.lines = [lines[0]] + sub_blk_list = [current_blk] + textblock_splitted = False + for jj, line in enumerate(lines[1:]): + l1, l2 = Polygon(lines[jj]), Polygon(line) + split = False + if not l1.intersects(l2): + line_disance = distance[jj+1] - distance[jj] + if line_disance > distance_tol: + split = True + else: + if blk.vertical and abs(abs(blk.angle) - 90) < 10: + split = abs(lines[jj][0][1] - line[0][1]) > font_size + 
if split: + current_blk = copy.deepcopy(current_blk) + current_blk.lines = [line] + sub_blk_list.append(current_blk) + else: + current_blk.lines.append(line) + if len(sub_blk_list) > 1: + textblock_splitted = True + for current_blk in sub_blk_list: + current_blk.adjust_bbox(with_bbox=False) + return textblock_splitted, sub_blk_list + +def group_output(blks, lines, im_w, im_h, mask=None, sort_blklist=True) -> List[TextBlock]: + blk_list, scattered_lines = [], {'ver': [], 'hor': []} + for bbox, cls, conf in zip(*blks): + blk_list.append(TextBlock(bbox, language=LANG_LIST[cls])) + + # step1: filter & assign lines to textblocks + bbox_score_thresh = 0.4 + mask_score_thresh = 0.1 + for ii, line in enumerate(lines): + bx1, bx2 = line[:, 0].min(), line[:, 0].max() + by1, by2 = line[:, 1].min(), line[:, 1].max() + bbox_score, bbox_idx = -1, -1 + line_area = (by2-by1) * (bx2-bx1) + for jj, blk in enumerate(blk_list): + score = union_area(blk.xyxy, [bx1, by1, bx2, by2]) / line_area + if bbox_score < score: + bbox_score = score + bbox_idx = jj + if bbox_score > bbox_score_thresh: + blk_list[bbox_idx].lines.append(line) + else: # if no textblock was assigned, check whether there is "enough" textmask + if mask is not None: + mask_score = mask[by1: by2, bx1: bx2].mean() / 255 + if mask_score < mask_score_thresh: + continue + blk = TextBlock([bx1, by1, bx2, by2], [line]) + examine_textblk(blk, im_w, im_h, True, sort=False) + if blk.vertical: + scattered_lines['ver'].append(blk) + else: + scattered_lines['hor'].append(blk) + + # step2: filter textblocks, sort & split textlines + final_blk_list = list() + for ii, blk in enumerate(blk_list): + # filter textblocks + if len(blk.lines) == 0: + bx1, by1, bx2, by2 = blk.xyxy + if mask is not None: + mask_score = mask[by1: by2, bx1: bx2].mean() / 255 + if mask_score < mask_score_thresh: + continue + xywh = np.array([[bx1, by1, bx2-bx1, by2-by1]]) + blk.lines = xywh2xyxypoly(xywh).reshape(-1, 4, 2).tolist() + lines = blk.lines_array() + eval_orientation = blk.language != 'eng' + examine_textblk(blk, im_w, im_h, eval_orientation, sort=True) + # split manga text if there is a distance gap + textblock_splitted = blk.language == 'ja' and len(blk.lines) > 1 + if textblock_splitted: + textblock_splitted, sub_blk_list = split_textblk(blk) + else: + sub_blk_list = [blk] + # modify textblock to fit its textlines + if not textblock_splitted: + for blk in sub_blk_list: + blk.adjust_bbox(with_bbox=True) + final_blk_list += sub_blk_list + + # step3: merge scattered lines, sort textblocks by "grid" + final_blk_list += merge_textlines(scattered_lines['hor']) + final_blk_list += merge_textlines(scattered_lines['ver']) + if sort_blklist: + final_blk_list = sort_textblk_list(final_blk_list, im_w, im_h) + return final_blk_list + +def visualize_textblocks(canvas, blk_list: List[TextBlock]): + lw = max(round(sum(canvas.shape) / 2 * 0.003), 2) # line width + for ii, blk in enumerate(blk_list): + bx1, by1, bx2, by2 = blk.xyxy + cv2.rectangle(canvas, (bx1, by1), (bx2, by2), (127, 255, 127), lw) + lines = blk.lines_array(dtype=np.int32) + for jj, line in enumerate(lines): + cv2.putText(canvas, str(jj), line[0], cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,127,0), 1) + cv2.polylines(canvas, [line], True, (0,127,255), 2) + cv2.polylines(canvas, [blk.mini_rect], True, (127,127,0), 2) + center = [int((bx1 + bx2)/2), int((by1 + by2)/2)] + cv2.putText(canvas, str(blk.angle), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (127,127,255), 2) + cv2.putText(canvas, str(ii), (bx1, by1 + lw + 2), 0, lw / 3, 
(255,127,127), max(lw-1, 1), cv2.LINE_AA) + return canvas \ No newline at end of file diff --git a/textblockdetector/textmask.py b/textblockdetector/textmask.py new file mode 100644 index 000000000..7efc70abb --- /dev/null +++ b/textblockdetector/textmask.py @@ -0,0 +1,170 @@ +from os import stat +from typing import List +import cv2 +import numpy as np +from .textblock import TextBlock +from .utils.imgproc_utils import draw_connected_labels, expand_textwindow, union_area + +WHITE = (255, 255, 255) +BLACK = (0, 0, 0) +LANG_ENG = 0 +LANG_JPN = 1 + +REFINEMASK_INPAINT = 0 +REFINEMASK_ANNOTATION = 1 + +def get_topk_color(color_list, bins, k=3, color_var=10, bin_tol=0.001): + idx = np.argsort(bins * -1) + color_list, bins = color_list[idx], bins[idx] + top_colors = [color_list[0]] + bin_tol = np.sum(bins) * bin_tol + if len(color_list) > 1: + for color, bin in zip(color_list[1:], bins[1:]): + if np.abs(np.array(top_colors) - color).min() > color_var: + top_colors.append(color) + if len(top_colors) >= k or bin < bin_tol: + break + return top_colors + +def minxor_thresh(threshed, mask, dilate=False): + neg_threshed = 255 - threshed + e_size = 1 + if dilate: + element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size)) + neg_threshed = cv2.dilate(neg_threshed, element, iterations=1) + threshed = cv2.dilate(threshed, element, iterations=1) + neg_xor_sum = cv2.bitwise_xor(neg_threshed, mask).sum() + xor_sum = cv2.bitwise_xor(threshed, mask).sum() + if neg_xor_sum < xor_sum: + return neg_threshed, neg_xor_sum + else: + return threshed, xor_sum + +def get_otsuthresh_masklist(img, pred_mask, per_channel=False) -> List[np.ndarray]: + channels = [img[..., 0], img[..., 1], img[..., 2]] + mask_list = [] + for c in channels: + _, threshed = cv2.threshold(c, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY) + threshed, xor_sum = minxor_thresh(threshed, pred_mask, dilate=False) + mask_list.append([threshed, xor_sum]) + mask_list.sort(key=lambda x: x[1]) + if per_channel: + return mask_list + else: + return [mask_list[0]] + +def get_topk_masklist(im_grey, pred_mask): + if len(im_grey.shape) == 3 and im_grey.shape[-1] == 3: + im_grey = cv2.cvtColor(im_grey, cv2.COLOR_BGR2GRAY) + msk = np.ascontiguousarray(pred_mask) + candidate_grey_px = im_grey[np.where(cv2.erode(msk, np.ones((3,3), np.uint8), iterations=1) > 127)] + bin, his = np.histogram(candidate_grey_px, bins=255) + topk_color = get_topk_color(his, bin, color_var=10, k=3) + color_range = 30 + mask_list = list() + for ii, color in enumerate(topk_color): + c_top = min(color+color_range, 255) + c_bottom = c_top - 2 * color_range + threshed = cv2.inRange(im_grey, c_bottom, c_top) + threshed, xor_sum = minxor_thresh(threshed, msk) + mask_list.append([threshed, xor_sum]) + return mask_list + +def merge_mask_list(mask_list, pred_mask, blk: TextBlock = None, pred_thresh=30, text_window=None, filter_with_lines=False, refine_mode=REFINEMASK_INPAINT): + mask_list.sort(key=lambda x: x[1]) + linemask = None + if blk is not None and filter_with_lines: + linemask = np.zeros_like(pred_mask) + lines = blk.lines_array(dtype=np.int64) + for line in lines: + line[..., 0] -= text_window[0] + line[..., 1] -= text_window[1] + cv2.fillPoly(linemask, [line], 255) + linemask = cv2.dilate(linemask, np.ones((3, 3), np.uint8), iterations=3) + + if pred_thresh > 0: + e_size = 1 + element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size)) + pred_mask = cv2.erode(pred_mask, element, iterations=1) + _, 
pred_mask = cv2.threshold(pred_mask, 60, 255, cv2.THRESH_BINARY) + connectivity = 8 + mask_merged = np.zeros_like(pred_mask) + for ii, (candidate_mask, xor_sum) in enumerate(mask_list): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(candidate_mask, connectivity, cv2.CV_16U) + for label_index, stat, centroid in zip(range(num_labels), stats, centroids): + if label_index != 0: # skip background label + x, y, w, h, area = stat + if w * h < 3: + continue + x1, y1, x2, y2 = x, y, x+w, y+h + label_local = labels[y1: y2, x1: x2] + label_cordinates = np.where(label_local==label_index) + tmp_merged = np.zeros_like(label_local, np.uint8) + tmp_merged[label_cordinates] = 255 + tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged) + xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum() + xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum() + if xor_merged < xor_origin: + mask_merged[y1: y2, x1: x2] = tmp_merged + + if refine_mode == REFINEMASK_INPAINT: + mask_merged = cv2.dilate(mask_merged, np.ones((5, 5), np.uint8), iterations=1) + # fill holes + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(255-mask_merged, connectivity, cv2.CV_16U) + sorted_area = np.sort(stats[:, -1]) + if len(sorted_area) > 1: + area_thresh = sorted_area[-2] + else: + area_thresh = sorted_area[-1] + for label_index, stat, centroid in zip(range(num_labels), stats, centroids): + x, y, w, h, area = stat + if area < area_thresh: + x1, y1, x2, y2 = x, y, x+w, y+h + label_local = labels[y1: y2, x1: x2] + label_cordinates = np.where(label_local==label_index) + tmp_merged = np.zeros_like(label_local, np.uint8) + tmp_merged[label_cordinates] = 255 + tmp_merged = cv2.bitwise_or(mask_merged[y1: y2, x1: x2], tmp_merged) + xor_merged = cv2.bitwise_xor(tmp_merged, pred_mask[y1: y2, x1: x2]).sum() + xor_origin = cv2.bitwise_xor(mask_merged[y1: y2, x1: x2], pred_mask[y1: y2, x1: x2]).sum() + if xor_merged < xor_origin: + mask_merged[y1: y2, x1: x2] = tmp_merged + return mask_merged + + +def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined: np.ndarray, blk_list: List[TextBlock], refine_mode=REFINEMASK_INPAINT): + mask_pred[np.where(mask_refined > 30)] = 0 + _, pred_mask_t = cv2.threshold(mask_pred, 30, 255, cv2.THRESH_BINARY) + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(pred_mask_t, 4, cv2.CV_16U) + valid_labels = np.where(stats[:, -1] > 50)[0] + seg_blk_list = [] + if len(valid_labels) > 0: + for lab_index in valid_labels[1:]: + x, y, w, h, area = stats[lab_index] + bx1, by1 = x, y + bx2, by2 = x+w, y+h + bbox = [bx1, by1, bx2, by2] + bbox_score = -1 + for blk in blk_list: + bbox_s = union_area(blk.xyxy, bbox) + if bbox_s > bbox_score: + bbox_score = bbox_s + if bbox_score / w / h < 0.5: + seg_blk_list.append(TextBlock(bbox)) + if len(seg_blk_list) > 0: + mask_refined = cv2.bitwise_or(mask_refined, refine_mask(img, mask_pred, seg_blk_list, refine_mode=refine_mode)) + return mask_refined + + +def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[TextBlock], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray: + mask_refined = np.zeros_like(pred_mask) + for blk in blk_list: + bx1, by1, bx2, by2 = expand_textwindow(img.shape, blk.xyxy, expand_r=16) + im = np.ascontiguousarray(img[by1: by2, bx1: bx2]) + msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2]) + mask_list = get_topk_masklist(im, msk) + mask_list += get_otsuthresh_masklist(im, msk, 
per_channel=False) + mask_merged = merge_mask_list(mask_list, msk, blk=blk, text_window=[bx1, by1, bx2, by2], refine_mode=refine_mode) + mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged) + return mask_refined + diff --git a/textblockdetector/utils/db_utils.py b/textblockdetector/utils/db_utils.py new file mode 100644 index 000000000..603650ff8 --- /dev/null +++ b/textblockdetector/utils/db_utils.py @@ -0,0 +1,695 @@ +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon +from collections import namedtuple +import warnings +warnings.filterwarnings('ignore') + + +def iou_rotate(box_a, box_b, method='union'): + rect_a = cv2.minAreaRect(box_a) + rect_b = cv2.minAreaRect(box_b) + r1 = cv2.rotatedRectangleIntersection(rect_a, rect_b) + if r1[0] == 0: + return 0 + else: + inter_area = cv2.contourArea(r1[1]) + area_a = cv2.contourArea(box_a) + area_b = cv2.contourArea(box_b) + union_area = area_a + area_b - inter_area + if union_area == 0 or inter_area == 0: + return 0 + if method == 'union': + iou = inter_area / union_area + elif method == 'intersection': + iou = inter_area / min(area_a, area_b) + else: + raise NotImplementedError + return iou + +class SegDetectorRepresenter(): + def __init__(self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=1.5): + self.min_size = 3 + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + + def __call__(self, batch, pred, is_output_polygon=False): + ''' + batch: (image, polygons, ignore_tags + batch: a dict produced by dataloaders. + image: tensor of shape (N, C, H, W). + polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. + ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. + shape: the original shape of images. + filename: the original filenames of images. 
+ pred: + binary: text region segmentation map, with shape (N, H, W) + thresh: [if exists] thresh hold prediction with shape (N, H, W) + thresh_binary: [if exists] binarized with threshhold, (N, H, W) + ''' + pred = pred[:, 0, :, :] + segmentation = self.binarize(pred) + boxes_batch = [] + scores_batch = [] + for batch_index in range(pred.size(0)): + # height, width = batch['shape'][batch_index] + height, width = pred.shape[1], pred.shape[2] + if is_output_polygon: + boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height) + else: + boxes, scores = self.boxes_from_bitmap(pred[batch_index], segmentation[batch_index], width, height) + boxes_batch.append(boxes) + scores_batch.append(scores) + return boxes_batch, scores_batch + + def binarize(self, pred): + return pred > self.thresh + + def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap.cpu().numpy() # The first channel + pred = pred.cpu().detach().numpy() + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:self.max_candidates]: + epsilon = 0.005 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + # _, sside = self.get_mini_boxes(contour) + # if sside < self.min_size: + # continue + score = self.box_score_fast(pred, contour.squeeze(1)) + if self.box_thresh > score: + continue + + if points.shape[0] > 2: + box = self.unclip(points, unclip_ratio=self.unclip_ratio) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < self.min_size + 2: + continue + + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box) + scores.append(score) + return boxes, scores + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (H, W), + whose values are binarized as {0, 1} + ''' + + assert len(_bitmap.shape) == 2 + bitmap = _bitmap.cpu().numpy() # The first channel + pred = pred.cpu().detach().numpy() + height, width = bitmap.shape + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + num_contours = min(len(contours), self.max_candidates) + boxes = np.zeros((num_contours, 4, 2), dtype=np.int16) + scores = np.zeros((num_contours,), dtype=np.float32) + + for index in range(num_contours): + contour = contours[index].squeeze(1) + points, sside = self.get_mini_boxes(contour) + # if sside < self.min_size: + # continue + if sside < 2: + continue + points = np.array(points) + score = self.box_score_fast(pred, contour) + # if self.box_thresh > score: + # continue + + box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + # if sside < 5: + # continue + box = np.array(box) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip(np.round(box[:, 0] / width * 
dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes[index, :, :] = box.astype(np.int16) + scores[index] = score + return boxes, scores + + def unclip(self, box, unclip_ratio=1.5): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + if bitmap.dtype == np.float16: + bitmap = bitmap.astype(np.float32) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + return self + + +class DetectionIoUEvaluator(object): + def __init__(self, is_output_polygon=False, iou_constraint=0.5, area_precision_constraint=0.5): + self.is_output_polygon = is_output_polygon + self.iou_constraint = iou_constraint + self.area_precision_constraint = area_precision_constraint + + def evaluate_image(self, gt, pred): + + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array 
of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + + evaluationLog = "" + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + gtPol = points + gtPols.append(gtPol) + gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len( + gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detPol = points + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = Polygon(detPol).area + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > self.area_precision_constraint): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len( + detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + if self.is_output_polygon: + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + else: + # gtPols = np.float32(gtPols) + # detPols = np.float32(detPols) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = np.float32(gtPols[gtNum]) + pD = np.float32(detPols[detNum]) + iouMat[gtNum, detNum] = iou_rotate(pD, pG) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > self.iou_constraint: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + \ + str(gtNum) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float( + detMatched) / numDetCare + + hmean = 0 if (precision + recall) == 0 else 2.0 * \ + precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'detMatched': detMatched, + 
            'evaluationLog': evaluationLog
+        }
+
+        return perSampleMetrics
+
+    def combine_results(self, results):
+        numGlobalCareGt = 0
+        numGlobalCareDet = 0
+        matchedSum = 0
+        for result in results:
+            numGlobalCareGt += result['gtCare']
+            numGlobalCareDet += result['detCare']
+            matchedSum += result['detMatched']
+
+        methodRecall = 0 if numGlobalCareGt == 0 else float(
+            matchedSum) / numGlobalCareGt
+        methodPrecision = 0 if numGlobalCareDet == 0 else float(
+            matchedSum) / numGlobalCareDet
+        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
+            methodRecall * methodPrecision / (
+                methodRecall + methodPrecision)
+
+        methodMetrics = {'precision': methodPrecision,
+                         'recall': methodRecall, 'hmean': methodHmean}
+
+        return methodMetrics
+
+class QuadMetric():
+    def __init__(self, is_output_polygon=False):
+        self.is_output_polygon = is_output_polygon
+        self.evaluator = DetectionIoUEvaluator(is_output_polygon=is_output_polygon)
+
+    def measure(self, batch, output, box_thresh=0.6):
+        '''
+        batch: (image, polygons, ignore_tags)
+        batch: a dict produced by dataloaders.
+            image: tensor of shape (N, C, H, W).
+            polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
+            ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
+            shape: the original shape of images.
+            filename: the original filenames of images.
+        output: (polygons, ...)
+        '''
+        results = []
+        gt_polyons_batch = batch['text_polys']
+        ignore_tags_batch = batch['ignore_tags']
+        pred_polygons_batch = np.array(output[0])
+        pred_scores_batch = np.array(output[1])
+        for polygons, pred_polygons, pred_scores, ignore_tags in zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
+            gt = [dict(points=np.int64(polygons[i]), ignore=ignore_tags[i]) for i in range(len(polygons))]
+            if self.is_output_polygon:
+                pred = [dict(points=pred_polygons[i]) for i in range(len(pred_polygons))]
+            else:
+                pred = []
+                # print(pred_polygons.shape)
+                for i in range(pred_polygons.shape[0]):
+                    if pred_scores[i] >= box_thresh:
+                        # print(pred_polygons[i,:,:].tolist())
+                        pred.append(dict(points=pred_polygons[i, :, :].astype(np.int)))
+                # pred = [dict(points=pred_polygons[i,:,:].tolist()) if pred_scores[i] >= box_thresh for i in range(pred_polygons.shape[0])]
+            results.append(self.evaluator.evaluate_image(gt, pred))
+        return results
+
+    def validate_measure(self, batch, output, box_thresh=0.6):
+        return self.measure(batch, output, box_thresh)
+
+    def evaluate_measure(self, batch, output):
+        return self.measure(batch, output), np.linspace(0, batch['image'].shape[0]).tolist()
+
+    def gather_measure(self, raw_metrics):
+        raw_metrics = [image_metrics
+                       for batch_metrics in raw_metrics
+                       for image_metrics in batch_metrics]
+
+        result = self.evaluator.combine_results(raw_metrics)
+
+        precision = AverageMeter()
+        recall = AverageMeter()
+        fmeasure = AverageMeter()
+
+        precision.update(result['precision'], n=len(raw_metrics))
+        recall.update(result['recall'], n=len(raw_metrics))
+        fmeasure_score = 2 * precision.val * recall.val / (precision.val + recall.val + 1e-8)
+        fmeasure.update(fmeasure_score)
+
+        return {
+            'precision': precision,
+            'recall': recall,
+            'fmeasure': fmeasure
+        }
+
+def shrink_polygon_py(polygon, shrink_ratio):
+    """
+    Shrink the polygon toward its centroid; scaling the result by 1 / shrink_ratio maps it back.
+    """
+    cx = polygon[:, 0].mean()
+    cy = polygon[:, 1].mean()
+    polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
+    polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
+    return polygon
+
+
+def shrink_polygon_pyclipper(polygon, shrink_ratio):
+    from shapely.geometry import Polygon
+    import pyclipper
+    polygon_shape = Polygon(polygon)
+    # offset distance as in the DB paper: D = A * (1 - r^2) / L (polygon area, shrink ratio, perimeter)
+    distance = polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
+    subject = [tuple(l) for l in polygon]
+    padding = pyclipper.PyclipperOffset()
+    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+    shrinked = padding.Execute(-distance)
+    if shrinked == []:
+        shrinked = np.array(shrinked)
+    else:
+        shrinked = np.array(shrinked[0]).reshape(-1, 2)
+    return shrinked
+
+class MakeShrinkMap():
+    r'''
+    Making binary mask from detection data with ICDAR format.
+    Typically following the process of class `MakeICDARData`.
+    '''
+
+    def __init__(self, min_text_size=4, shrink_ratio=0.4, shrink_type='pyclipper'):
+        shrink_func_dict = {'py': shrink_polygon_py, 'pyclipper': shrink_polygon_pyclipper}
+        self.shrink_func = shrink_func_dict[shrink_type]
+        self.min_text_size = min_text_size
+        self.shrink_ratio = shrink_ratio
+
+    def __call__(self, data: dict) -> dict:
+        """
+        Compute the shrink map and shrink mask for the text polygons in `data`.
+        :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
+        :return:
+        """
+        image = data['imgs']
+        text_polys = data['text_polys']
+        ignore_tags = data['ignore_tags']
+
+        h, w = image.shape[:2]
+        text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
+        gt = np.zeros((h, w), dtype=np.float32)
+        mask = np.ones((h, w), dtype=np.float32)
+        for i in range(len(text_polys)):
+            polygon = text_polys[i]
+            height = max(polygon[:, 1]) - min(polygon[:, 1])
+            width = max(polygon[:, 0]) - min(polygon[:, 0])
+            if ignore_tags[i] or min(height, width) < self.min_text_size:
+                cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
+                ignore_tags[i] = True
+            else:
+                shrinked = self.shrink_func(polygon, self.shrink_ratio)
+                if shrinked.size == 0:
+                    cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
+                    ignore_tags[i] = True
+                    continue
+                cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
+
+        data['shrink_map'] = gt
+        data['shrink_mask'] = mask
+        return data
+
+    def validate_polygons(self, polygons, ignore_tags, h, w):
+        '''
+        polygons (numpy.array, required): of shape (num_instances, num_points, 2)
+        '''
+        if len(polygons) == 0:
+            return polygons, ignore_tags
+        assert len(polygons) == len(ignore_tags)
+        for polygon in polygons:
+            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
+            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
+
+        for i in range(len(polygons)):
+            area = self.polygon_area(polygons[i])
+            if abs(area) < 1:
+                ignore_tags[i] = True
+            if area > 0:
+                polygons[i] = polygons[i][::-1, :]
+        return polygons, ignore_tags
+
+    def polygon_area(self, polygon):
+        return cv2.contourArea(polygon)
+
+
+class MakeBorderMap():
+    def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
+        self.shrink_ratio = shrink_ratio
+        self.thresh_min = thresh_min
+        self.thresh_max = thresh_max
+
+    def __call__(self, data: dict) -> dict:
+        """
+        Compute the threshold (border) map and mask for the text polygons in `data`.
+        :param data: {'imgs':,'text_polys':,'texts':,'ignore_tags':}
+        :return:
+        """
+        im = data['imgs']
+        text_polys = data['text_polys']
+        ignore_tags = data['ignore_tags']
+
+        canvas = np.zeros(im.shape[:2], dtype=np.float32)
+        mask = np.zeros(im.shape[:2], dtype=np.float32)
+
+        for i in range(len(text_polys)):
+            if ignore_tags[i]:
+                continue
+            self.draw_border_map(text_polys[i], canvas, mask=mask)
+        canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
+
+        data['threshold_map'] = canvas
+        data['threshold_mask'] = mask
+        return data
+
+    def
draw_border_map(self, polygon, canvas, mask): + polygon = np.array(polygon) + assert polygon.ndim == 2 + assert polygon.shape[1] == 2 + + polygon_shape = Polygon(polygon) + if polygon_shape.area <= 0: + return + distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + + padded_polygon = np.array(padding.Execute(distance)[0]) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) + + distance_map = np.zeros( + (polygon.shape[0], height, width), dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( + 1 - distance_map[ + ymin_valid - ymin:ymax_valid - ymax + height, + xmin_valid - xmin:xmax_valid - xmax + width], + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) + + def distance(self, xs, ys, point_1, point_2): + ''' + compute the distance from point to a line + ys: coordinates in the first axis + xs: coordinates in the second axis + point_1, point_2: (x, y), the end of the line + ''' + height, width = xs.shape[:2] + square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1]) + square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1]) + square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) + + cosin = (square_distance - square_distance_1 - square_distance_2) / (2 * np.sqrt(square_distance_1 * square_distance_2)) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + + result = np.sqrt(square_distance_1 * square_distance_2 * square_sin / square_distance) + result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin < 0] + return result + + def extend_line(self, point_1, point_2, result): + ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), + int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) + cv2.line(result, tuple(ex_point_1), tuple(point_1), 4096.0, 1, lineType=cv2.LINE_AA, shift=0) + ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), + int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) + cv2.line(result, tuple(ex_point_2), tuple(point_2), 4096.0, 1, lineType=cv2.LINE_AA, shift=0) + return ex_point_1, ex_point_2 \ No newline at end of file diff --git a/textblockdetector/utils/imgproc_utils.py b/textblockdetector/utils/imgproc_utils.py new file mode 100644 index 
000000000..ee87ca9ef --- /dev/null +++ b/textblockdetector/utils/imgproc_utils.py @@ -0,0 +1,171 @@ +import numpy as np +import cv2 +import random + +def hex2bgr(hex): + gmask = 254 << 8 + rmask = 254 + b = hex >> 16 + g = (hex & gmask) >> 8 + r = hex & rmask + return np.stack([b, g, r]).transpose() + +def union_area(bboxa, bboxb): + x1 = max(bboxa[0], bboxb[0]) + y1 = max(bboxa[1], bboxb[1]) + x2 = min(bboxa[2], bboxb[2]) + y2 = min(bboxa[3], bboxb[3]) + if y2 < y1 or x2 < x1: + return -1 + return (y2 - y1) * (x2 - x1) + +def get_yololabel_strings(clslist, labellist): + content = '' + for cls, xywh in zip(clslist, labellist): + content += str(int(cls)) + ' ' + ' '.join([str(e) for e in xywh]) + '\n' + if len(content) != 0: + content = content[:-1] + return content + +def xywh2xyxypoly(xywh): + xyxypoly = np.tile(xywh[:, [0, 1]], 4) + xyxypoly[:, [2, 4]] += xywh[:, [2]] + xyxypoly[:, [5, 7]] += xywh[:, [3]] + return xyxypoly.astype(np.int64) + +def xyxy2yolo(xyxy, w: int, h: int, ordered=False): + if xyxy == [] or xyxy == np.array([]) or len(xyxy) == 0: + return None + if isinstance(xyxy, list): + xyxy = np.array(xyxy) + if len(xyxy.shape) == 1: + xyxy = np.array([xyxy]) + yolo = np.copy(xyxy).astype(np.float64) + yolo[:, [0, 2]] = yolo[:, [0, 2]] / w + yolo[:, [1, 3]] = yolo[:, [1, 3]] / h + yolo[:, [2, 3]] -= yolo[:, [0, 1]] + yolo[:, [0, 1]] += yolo[:, [2, 3]] / 2 + return yolo + +def yolo_xywh2xyxy(xywh: np.array, w: int, h: int, to_int=True): + if xywh is None: + return None + if len(xywh) == 0: + return None + if len(xywh.shape) == 1: + xywh = np.array([xywh]) + xywh[:, [0, 2]] *= w + xywh[:, [1, 3]] *= h + xywh[:, [0, 1]] -= xywh[:, [2, 3]] / 2 + xywh[:, [2, 3]] += xywh[:, [0, 1]] + return xywh.astype(np.int64) + +def rotate_polygons(center, polygons, rotation, new_center=None, to_int=True): + if new_center is None: + new_center = center + rotation = np.deg2rad(rotation) + s, c = np.sin(rotation), np.cos(rotation) + polygons = polygons.astype(np.float32) + + polygons[:, 1::2] -= center[1] + polygons[:, ::2] -= center[0] + rotated = np.copy(polygons) + rotated[:, 1::2] = polygons[:, 1::2] * c - polygons[:, ::2] * s + rotated[:, ::2] = polygons[:, 1::2] * s + polygons[:, ::2] * c + rotated[:, 1::2] += new_center[1] + rotated[:, ::2] += new_center[0] + if to_int: + return rotated.astype(np.int64) + return rotated + +def letterbox(im, new_shape=(640, 640), color=(0, 0, 0), auto=False, scaleFill=False, scaleup=True, stride=128): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if not isinstance(new_shape, tuple): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + # dw /= 2 # divide padding into 2 sides + # dh /= 2 + dh, dw = int(dh), int(dw) + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 
0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, 0, dh, 0, dw, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + +def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR): + shape = im.shape[:2] # current shape [height, width] + if not isinstance(new_shape, tuple): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=interpolation) + return im + +def expand_textwindow(img_size, xyxy, expand_r=8, shrink=False): + im_h, im_w = img_size[:2] + x1, y1 , x2, y2 = xyxy + w = x2 - x1 + h = y2 - y1 + paddings = int(round((max(h, w) * 0.25 + min(h, w) * 0.75) / expand_r)) + if shrink: + paddings *= -1 + x1, y1 = max(0, x1 - paddings), max(0, y1 - paddings) + x2, y2 = min(im_w-1, x2+paddings), min(im_h-1, y2+paddings) + return [x1, y1, x2, y2] + +def draw_connected_labels(num_labels, labels, stats, centroids, names="draw_connected_labels", skip_background=True): + labdraw = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8) + max_ind = 0 + if isinstance(num_labels, int): + num_labels = range(num_labels) + + # for ind, lab in enumerate((range(num_labels))): + for lab in num_labels: + if skip_background and lab == 0: + continue + randcolor = (random.randint(0,255), random.randint(0,255), random.randint(0,255)) + labdraw[np.where(labels==lab)] = randcolor + maxr, minr = 0.5, 0.001 + maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr + minarea = labdraw.shape[0] * labdraw.shape[1] * minr + + stat = stats[lab] + bboxarea = stat[2] * stat[3] + if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea: + pix = np.zeros((labels.shape[0], labels.shape[1]), dtype=np.uint8) + pix[np.where(labels==lab)] = 255 + + rect = cv2.minAreaRect(cv2.findNonZero(pix)) + box = np.int0(cv2.boxPoints(rect)) + labdraw = cv2.drawContours(labdraw, [box], 0, randcolor, 2) + labdraw = cv2.circle(labdraw, (int(centroids[lab][0]),int(centroids[lab][1])), radius=5, color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)), thickness=-1) + + cv2.imshow(names, labdraw) + return labdraw + diff --git a/textblockdetector/utils/io_utils.py b/textblockdetector/utils/io_utils.py new file mode 100644 index 000000000..00bb08862 --- /dev/null +++ b/textblockdetector/utils/io_utils.py @@ -0,0 +1,54 @@ +import os +import os.path as osp +import glob +from pathlib import Path +import cv2 +import numpy as np +import json + +IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg'] + +NP_BOOL_TYPES = (np.bool_, np.bool8) +NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64) +NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64) + +# https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.ScalarType): + if isinstance(obj, NP_BOOL_TYPES): + return bool(obj) + elif isinstance(obj, NP_FLOAT_TYPES): + return float(obj) + elif isinstance(obj, NP_INT_TYPES): + return int(obj) + return json.JSONEncoder.default(self, obj) + +def 
find_all_imgs(img_dir, abs_path=False): + imglist = list() + for filep in glob.glob(osp.join(img_dir, "*")): + filename = osp.basename(filep) + file_suffix = Path(filename).suffix + if file_suffix.lower() not in IMG_EXT: + continue + if abs_path: + imglist.append(filep) + else: + imglist.append(filename) + return imglist + +def imread(imgpath, read_type=cv2.IMREAD_COLOR): + # img = cv2.imread(imgpath, read_type) + # if img is None: + img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), read_type) + return img + +def imwrite(img_path, img, ext='.png'): + suffix = Path(img_path).suffix + if suffix != '': + img_path = img_path.replace(suffix, ext) + else: + img_path += ext + cv2.imencode(ext, img)[1].tofile(img_path) \ No newline at end of file diff --git a/textblockdetector/utils/weight_init.py b/textblockdetector/utils/weight_init.py new file mode 100644 index 000000000..8d43868c7 --- /dev/null +++ b/textblockdetector/utils/weight_init.py @@ -0,0 +1,103 @@ +import torch.nn as nn +import torch + +def constant_init(module, val, bias=0): + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +def xavier_init(module, gain=1, bias=0, distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def normal_init(module, mean=0, std=1, bias=0): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def uniform_init(module, a=0, b=1, bias=0): + nn.init.uniform_(module.weight, a, b) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def kaiming_init(module, + a=0, + is_rnn=False, + mode='fan_in', + nonlinearity='leaky_relu', + bias=0, + distribution='normal'): + assert distribution in ['uniform', 'normal'] + if distribution == 'uniform': + if is_rnn: + for name, param in module.named_parameters(): + if 'bias' in name: + nn.init.constant_(param, bias) + elif 'weight' in name: + nn.init.kaiming_uniform_(param, + a=a, + mode=mode, + nonlinearity=nonlinearity) + else: + nn.init.kaiming_uniform_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) + + else: + if is_rnn: + for name, param in module.named_parameters(): + if 'bias' in name: + nn.init.constant_(param, bias) + elif 'weight' in name: + nn.init.kaiming_normal_(param, + a=a, + mode=mode, + nonlinearity=nonlinearity) + else: + nn.init.kaiming_normal_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) + + if not is_rnn and hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + + +def bilinear_kernel(in_channels, out_channels, kernel_size): + factor = (kernel_size + 1) // 2 + if kernel_size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = (torch.arange(kernel_size).reshape(-1, 1), + torch.arange(kernel_size).reshape(1, -1)) + filt = (1 - torch.abs(og[0] - center) / factor) * \ + (1 - torch.abs(og[1] - center) / factor) + weight = torch.zeros((in_channels, out_channels, + kernel_size, kernel_size)) + weight[range(in_channels), range(out_channels), :, :] = filt + return weight + + +def init_weights(m): + # for m in modules: + + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif 
isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + constant_init(m, 1) + elif isinstance(m, nn.Linear): + xavier_init(m) + elif isinstance(m, (nn.LSTM, nn.LSTMCell)): + kaiming_init(m, is_rnn=True) + # elif isinstance(m, nn.ConvTranspose2d): + # m.weight.data.copy_(bilinear_kernel(m.in_channels, m.out_channels, 4)); diff --git a/textblockdetector/utils/yolov5_utils.py b/textblockdetector/utils/yolov5_utils.py new file mode 100644 index 000000000..146b8f7e3 --- /dev/null +++ b/textblockdetector/utils/yolov5_utils.py @@ -0,0 +1,240 @@ +import math +import torch +import torch.nn as nn +import pkg_resources as pkg +import torch.nn.functional as F +import cv2 +import numpy as np +import time +import torchvision + +def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) + # scales img(bs,3,y,x) by ratio constrained to gs-multiple + if ratio == 1.0: + return img + else: + h, w = img.shape[2:] + s = (int(h * ratio), int(w * ratio)) # new size + img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize + if not same_shape: # pad/crop img + h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) + return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = nn.Conv2d(conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True).requires_grad_(False).to(conv.weight.device) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchors.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + m.anchors[:] = m.anchors.flip(0) + +def initialize_weights(model): + for m in model.modules(): + t = type(m) + if t is nn.Conv2d: + pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif t is nn.BatchNorm2d: + m.eps = 1e-3 + m.momentum = 0.03 + elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True + +def make_divisible(x, divisor): + # Returns nearest x divisible by divisor + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int + return math.ceil(x / divisor) * divisor + +def intersect_dicts(da, db, exclude=()): + # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values + return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} + +def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False): + # Check version vs. 
required version + current, minimum = (pkg.parse_version(x) for x in (current, minimum)) + result = (current == minimum) if pinned else (current >= minimum) # bool + if hard: # assert min requirements met + assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' + else: + return result + +class Colors: + # Ultralytics color palette https://ultralytics.com/ + def __init__(self): + # hex = matplotlib.colors.TABLEAU_COLORS.values() + hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', + '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7') + self.palette = [self.hex2rgb('#' + c) for c in hex] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order (PIL) + return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + # box = 4xn + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, + labels=(), max_det=300): + """Runs Non-Maximum Suppression (NMS) on inference results + + Returns: + list of detections, on (n,6) tensor per image [xyxy, conf, cls] + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Checks + assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' + assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' + + # Settings + min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + l = labels[xi] + v = torch.zeros((len(l), nc + 5), device=x.device) + v[:, :4] = l[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, 
center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Apply finite constraint + # if not torch.isfinite(x).all(): + # x = x[torch.isfinite(x).all(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + elif n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if i.shape[0] > max_det: # limit detections + i = i[:max_det] + if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f'WARNING: NMS time limit {time_limit}s exceeded') + break # time limit exceeded + + return output + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + +DEFAULT_LANG_LIST = ['eng', 'ja'] +def draw_bbox(pred, img, lang_list=None): + if lang_list is None: + lang_list = DEFAULT_LANG_LIST + lw = max(round(sum(img.shape) / 2 * 0.003), 2) # line width + pred = pred.astype(np.int32) + colors = Colors() + img = np.copy(img) + for ii, obj in enumerate(pred): + p1, p2 = (obj[0], obj[1]), (obj[2], obj[3]) + label = lang_list[obj[-1]] + str(ii+1) + cv2.rectangle(img, p1, p2, colors(obj[-1], bgr=True), lw, lineType=cv2.LINE_AA) + t_w, t_h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=lw)[0] + cv2.putText(img, label, (p1[0], p1[1] + t_h + 2), 0, lw / 3, colors(obj[-1], bgr=True), max(lw-1, 1), cv2.LINE_AA) + return img \ No newline at end of file diff --git a/textblockdetector/yolov5/common.py b/textblockdetector/yolov5/common.py new file mode 100644 index 000000000..89e3a70fa --- /dev/null +++ b/textblockdetector/yolov5/common.py @@ -0,0 +1,290 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Common modules +""" + +import json +import math +import platform +import warnings +from collections import OrderedDict, namedtuple +from copy import copy +from pathlib import Path + +import cv2 +import numpy as np +import pandas as pd +import requests +import torch +import torch.nn as nn +from PIL import Image +from torch.cuda import amp + +from ..utils.yolov5_utils import make_divisible, initialize_weights, check_anchor_order, check_version, fuse_conv_and_bn + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 
if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + if isinstance(act, bool): + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + elif isinstance(act, str): + if act == 'leaky': + self.act = nn.LeakyReLU(0.1, inplace=True) + elif act == 'relu': + self.act = nn.ReLU(inplace=True) + else: + self.act = None + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def forward_fuse(self, x): + return self.act(self.conv(x)) + + +class DWConv(Conv): + # Depth-wise convolution class + def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class TransformerLayer(nn.Module): + # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) + def __init__(self, c, num_heads): + super().__init__() + self.q = nn.Linear(c, c, bias=False) + self.k = nn.Linear(c, c, bias=False) + self.v = nn.Linear(c, c, bias=False) + self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) + self.fc1 = nn.Linear(c, c, bias=False) + self.fc2 = nn.Linear(c, c, bias=False) + + def forward(self, x): + x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x + x = self.fc2(self.fc1(x)) + x + return x + + +class TransformerBlock(nn.Module): + # Vision Transformer https://arxiv.org/abs/2010.11929 + def __init__(self, c1, c2, num_heads, num_layers): + super().__init__() + self.conv = None + if c1 != c2: + self.conv = Conv(c1, c2) + self.linear = nn.Linear(c2, c2) # learnable position embedding + self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers))) + self.c2 = c2 + + def forward(self, x): + if self.conv is not None: + x = self.conv(x) + b, _, w, h = x.shape + p = x.flatten(2).permute(2, 0, 1) + return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act=True): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1, act=act) + self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.SiLU() + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, 
e=0.5, act=True): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1, act=act) + self.cv2 = Conv(c1, c_, 1, 1, act=act) + self.cv3 = Conv(2 * c_, c2, 1, act=act) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n))) + # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class C3TR(C3): + # C3 module with TransformerBlock() + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = TransformerBlock(c_, c_, 4, n) + + +class C3SPP(C3): + # C3 module with SPP() + def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = SPP(c_, c_, k) + + +class C3Ghost(C3): + # C3 module with GhostBottleneck() + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) + + +class SPP(nn.Module): + # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729 + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class SPPF(nn.Module): + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher + def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y1 = self.m(x) + y2 = self.m(y1) + return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + # self.contract = Contract(gain=2) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + # return self.conv(self.contract(x)) + + +class GhostConv(nn.Module): + # Ghost Convolution https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups + super().__init__() + c_ = c2 // 2 # hidden channels + self.cv1 = Conv(c1, c_, k, s, None, g, act) + self.cv2 = Conv(c_, c_, 5, 1, None, c_, act) + + def forward(self, x): + y = self.cv1(x) + return torch.cat([y, self.cv2(y)], 1) + + +class GhostBottleneck(nn.Module): + # Ghost Bottleneck https://github.com/huawei-noah/ghostnet + def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride + 
super().__init__() + c_ = c2 // 2 + self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw + DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw + GhostConv(c_, c2, 1, 1, act=False)) # pw-linear + self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), + Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + + def forward(self, x): + return self.conv(x) + self.shortcut(x) + + +class Contract(nn.Module): + # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain' + s = self.gain + x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) + return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40) + + +class Expand(nn.Module): + # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' + s = self.gain + x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) + return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class Classify(nn.Module): + # Classification head, i.e. x(b,c1,20,20) to x(b,c2) + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1) + self.flat = nn.Flatten() + + def forward(self, x): + z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list + return self.flat(self.conv(z)) # flatten to x(b,c2) + + diff --git a/textblockdetector/yolov5/yolo.py b/textblockdetector/yolov5/yolo.py new file mode 100644 index 000000000..e7e0a454e --- /dev/null +++ b/textblockdetector/yolov5/yolo.py @@ -0,0 +1,311 @@ +from operator import mod +from cv2 import imshow +# from utils.yolov5_utils import scale_img +from copy import deepcopy +from .common import * + +class Detect(nn.Module): + stride = None # strides computed during build + onnx_dynamic = False # ONNX export parameter + + def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 # number of outputs per anchor + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid + self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + self.inplace = inplace # use in-place ops (e.g. 
slice assignment) + + def forward(self, x): + z = [] # inference output + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) + + y = x[i].sigmoid() + if self.inplace: + y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 + xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i] # xy + wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + y = torch.cat((xy, wh, y[..., 4:]), -1) + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + def _make_grid(self, nx=20, ny=20, i=0): + d = self.anchors[i].device + if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility + yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij') + else: + yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)]) + grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float() + anchor_grid = (self.anchors[i].clone() * self.stride[i]) \ + .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float() + return grid, anchor_grid + +class Model(nn.Module): + def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes + super().__init__() + self.out_indices = None + if isinstance(cfg, dict): + self.yaml = cfg # model dict + else: # is *.yaml + import yaml # for torch hub + self.yaml_file = Path(cfg).name + with open(cfg, encoding='ascii', errors='ignore') as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels + if nc and nc != self.yaml['nc']: + # LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml['nc'] = nc # override yaml value + if anchors: + # LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}') + self.yaml['anchors'] = round(anchors) # override yaml value + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml['nc'])] # default names + self.inplace = self.yaml.get('inplace', True) + + # Build strides, anchors + m = self.model[-1] # Detect() + # with torch.no_grad(): + if isinstance(m, Detect): + s = 256 # 2x min stride + m.inplace = self.inplace + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + # Init weights, biases + initialize_weights(self) + + def forward(self, x, augment=False, profile=False, visualize=False, detect=False): + # if augment: + # return self._forward_augment(x) # augmented inference, None + return self._forward_once(x, profile, visualize, detect=detect) # single-scale inference, train + + # def _forward_augment(self, x): + # img_size = x.shape[-2:] # height, width + # s = [1, 0.83, 0.67] # scales + # f = [None, 3, None] # flips (2-ud, 3-lr) + # y = [] # outputs + # for si, fi in 
zip(s, f): + # xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) + # yi = self._forward_once(xi)[0] # forward + # # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save + # yi = self._descale_pred(yi, fi, si, img_size) + # y.append(yi) + # y = self._clip_augmented(y) # clip augmented tails + # return torch.cat(y, 1), None # augmented inference, train + + def _forward_once(self, x, profile=False, visualize=False, detect=False): + y, dt = [], [] # outputs + z = [] + for ii, m in enumerate(self.model): + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + if self.out_indices is not None: + if m.i in self.out_indices: + z.append(x) + if self.out_indices is not None: + if detect: + return x, z + else: + return z + else: + return x + + def _descale_pred(self, p, flips, scale, img_size): + # de-scale predictions following augmented inference (inverse operation) + if self.inplace: + p[..., :4] /= scale # de-scale + if flips == 2: + p[..., 1] = img_size[0] - p[..., 1] # de-flip ud + elif flips == 3: + p[..., 0] = img_size[1] - p[..., 0] # de-flip lr + else: + x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale + if flips == 2: + y = img_size[0] - y # de-flip ud + elif flips == 3: + x = img_size[1] - x # de-flip lr + p = torch.cat((x, y, wh, p[..., 4:]), -1) + return p + + def _clip_augmented(self, y): + # Clip YOLOv5 augmented inference tails + nl = self.model[-1].nl # number of detection layers (P3-P5) + g = sum(4 ** x for x in range(nl)) # grid points + e = 1 # exclude layer count + i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices + y[0] = y[0][:, :-i] # large + i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices + y[-1] = y[-1][:, i:] # small + return y + + def _profile_one_layer(self, m, x, dt): + c = isinstance(m, Detect) # is final layer, copy input as inplace fix + for _ in range(10): + m(x.copy() if c else x) + + + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # https://arxiv.org/abs/1708.02002 section 3.3 + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1. 
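+        # The added biases below encode a prior: the objectness logit starts out expecting
+        # roughly 8 objects per 640x640 image at each stride, and each class logit a small
+        # uniform prior of about 0.6 / (nc - 1), following the RetinaNet-style initialization cited above.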
+ m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + for m in self.model.modules(): + if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, 'bn') # remove batchnorm + m.forward = m.forward_fuse # update forward + # self.info() + return self + + # def info(self, verbose=False, img_size=640): # print model information + # model_info(self, verbose, img_size) + + def _apply(self, fn): + # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + self = super()._apply(fn) + m = self.model[-1] # Detect() + if isinstance(m, Detect): + m.stride = fn(m.stride) + m.grid = list(map(fn, m.grid)) + if isinstance(m.anchor_grid, list): + m.anchor_grid = list(map(fn, m.anchor_grid)) + return self + +def parse_model(d, ch): # model_dict, input_channels(3) + # LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}") + anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except NameError: + pass + + n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]: + c1, c2 = ch[f], args[0] + if c2 != no: # if not output + c2 = make_divisible(c2 * gw, 8) + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3, C3TR, C3Ghost]: + args.insert(2, n) # number of repeats + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[x] for x in f) + elif m is Detect: + args.append([ch[x] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + elif m is Contract: + c2 = ch[f] * args[0] ** 2 + elif m is Expand: + c2 = ch[f] // args[0] ** 2 + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace('__main__.', '') # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + # LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + if i == 0: + ch = [] + ch.append(c2) + return nn.Sequential(*layers), sorted(save) + +def load_yolov5(weights, map_location='cuda', fuse=True, inplace=True, 
out_indices=[1, 3, 5, 7, 9]): + if isinstance(weights, str): + ckpt = torch.load(weights, map_location=map_location) # load + else: + ckpt = weights + + if fuse: + model = ckpt['model'].float().fuse().eval() # FP32 model + else: + model = ckpt['model'].float().eval() # without layer fuse + + # Compatibility updates + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: + m.inplace = inplace # pytorch 1.7.0 compatibility + if type(m) is Detect: + if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility + delattr(m, 'anchor_grid') + setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) + elif type(m) is Conv: + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + model.out_indices = out_indices + return model + +@torch.no_grad() +def load_yolov5_ckpt(weights, map_location='cpu', fuse=True, inplace=True, out_indices=[1, 3, 5, 7, 9]): + if isinstance(weights, str): + ckpt = torch.load(weights, map_location=map_location) # load + else: + ckpt = weights + + model = Model(ckpt['cfg']) + model.load_state_dict(ckpt['weights'], strict=True) + + if fuse: + model = model.float().fuse().eval() # FP32 model + else: + model = model.float().eval() # without layer fuse + + # Compatibility updates + for m in model.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: + m.inplace = inplace # pytorch 1.7.0 compatibility + if type(m) is Detect: + if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility + delattr(m, 'anchor_grid') + setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) + elif type(m) is Conv: + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + model.out_indices = out_indices + return model \ No newline at end of file diff --git a/translate_demo.py b/translate_demo.py index cafddfb91..4ebc5e6e3 100755 --- a/translate_demo.py +++ b/translate_demo.py @@ -23,6 +23,7 @@ parser.add_argument('--text-mag-ratio', default=1, type=int, help='text rendering magnification ratio, larger means higher quality') parser.add_argument('--translator', default='google', type=str, help='language translator') parser.add_argument('--target-lang', default='CHS', type=str, help='destination language') +parser.add_argument('--use-ctd', action='store_true', help='use comic-text-detector for text detection') parser.add_argument('--verbose', action='store_true', help='print debug info and save intermediate images') args = parser.parse_args() print(args) @@ -59,6 +60,8 @@ def get_task(nonce) : from text_mask import dispatch as dispatch_mask_refinement from textline_merge import dispatch as dispatch_textline_merge from text_rendering import dispatch as dispatch_rendering +from textblockdetector import dispatch as dispatch_ctd_detection +from textblockdetector.textblock import visualize_textblocks async def infer( img, @@ -83,28 +86,46 @@ async def infer( if mode == 'web' and task_id : update_state(task_id, nonce, 'detection') - textlines, mask = await dispatch_detection(img, img_detect_size, args.use_cuda, args, verbose = args.verbose) + + if args.use_ctd: + mask, final_mask, textlines = await dispatch_ctd_detection(img, args.use_cuda) + text_regions = textlines + else: + textlines, mask = await dispatch_detection(img, img_detect_size, args.use_cuda, args, verbose = args.verbose) if args.verbose : - img_bbox_raw = np.copy(img) - for txtln in textlines : - cv2.polylines(img_bbox_raw, [txtln.pts], True, color = (255, 0, 0), thickness = 2) - 
cv2.imwrite(f'result/{task_id}/bbox_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR)) - cv2.imwrite(f'result/{task_id}/mask_raw.png', mask) + if args.use_ctd: + bboxes = visualize_textblocks(cv2.cvtColor(img,cv2.COLOR_BGR2RGB), textlines) + cv2.imwrite(f'result/{task_id}/bboxes.png', bboxes) + cv2.imwrite(f'result/{task_id}/mask_raw.png', mask) + cv2.imwrite(f'result/{task_id}/mask_final.png', final_mask) + else: + img_bbox_raw = np.copy(img) + for txtln in textlines : + cv2.polylines(img_bbox_raw, [txtln.pts], True, color = (255, 0, 0), thickness = 2) + cv2.imwrite(f'result/{task_id}/bbox_unfiltered.png', cv2.cvtColor(img_bbox_raw, cv2.COLOR_RGB2BGR)) + cv2.imwrite(f'result/{task_id}/mask_raw.png', mask) if mode == 'web' and task_id : update_state(task_id, nonce, 'ocr') textlines = await dispatch_ocr(img, textlines, args.use_cuda, args) - text_regions, textlines = await dispatch_textline_merge(textlines, img.shape[1], img.shape[0], verbose = args.verbose) - if args.verbose : - img_bbox = np.copy(img) - for region in text_regions : - for idx in region.textline_indices : - txtln = textlines[idx] - cv2.polylines(img_bbox, [txtln.pts], True, color = (255, 0, 0), thickness = 2) - img_bbox = cv2.polylines(img_bbox, [region.pts], True, color = (0, 0, 255), thickness = 2) - cv2.imwrite(f'result/{task_id}/bbox.png', cv2.cvtColor(img_bbox, cv2.COLOR_RGB2BGR)) + if not args.use_ctd: + text_regions, textlines = await dispatch_textline_merge(textlines, img.shape[1], img.shape[0], verbose = args.verbose) + if args.verbose : + img_bbox = np.copy(img) + for region in text_regions : + for idx in region.textline_indices : + txtln = textlines[idx] + cv2.polylines(img_bbox, [txtln.pts], True, color = (255, 0, 0), thickness = 2) + img_bbox = cv2.polylines(img_bbox, [region.pts], True, color = (0, 0, 255), thickness = 2) + cv2.imwrite(f'result/{task_id}/bbox.png', cv2.cvtColor(img_bbox, cv2.COLOR_RGB2BGR)) + + print(' -- Generating text mask') + if mode == 'web' and task_id : + update_state(task_id, nonce, 'mask_generation') + # create mask + final_mask = await dispatch_mask_refinement(img, mask, textlines) if mode == 'web' and task_id : print(' -- Translating') @@ -112,11 +133,7 @@ async def infer( # in web mode, we can start translation task async requests.post('http://127.0.0.1:5003/request-translation-internal', json = {'task_id': task_id, 'nonce': nonce, 'texts': [r.text for r in text_regions]}) - print(' -- Generating text mask') - if mode == 'web' and task_id : - update_state(task_id, nonce, 'mask_generation') - # create mask - final_mask = await dispatch_mask_refinement(img, mask, textlines) + print(' -- Running inpainting') if mode == 'web' and task_id : @@ -125,7 +142,7 @@ async def infer( if text_regions : img_inpainted = await dispatch_inpainting(args.use_inpainting, False, args.use_cuda, img, final_mask, args.inpainting_size, verbose = args.verbose) else : - img_inpainted = img + img_inpainted = img, img if args.verbose : img_inpainted, inpaint_input = img_inpainted cv2.imwrite(f'result/{task_id}/inpaint_input.png', cv2.cvtColor(inpaint_input, cv2.COLOR_RGB2BGR)) @@ -133,34 +150,44 @@ async def infer( cv2.imwrite(f'result/{task_id}/mask_final.png', final_mask) # translate text region texts + translated_sentences = None if mode != 'web' : print(' -- Translating') + # try: from translators import dispatch as run_translation - translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.text for r in text_regions]) + if args.use_ctd: + translated_sentences 
= await run_translation(args.translator, 'auto', args.target_lang, [r.get_text() for r in text_regions]) + else: + translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.text for r in text_regions]) + else : # wait for at most 1 hour - translated_sentences = None for _ in range(36000) : ret = requests.post('http://127.0.0.1:5003/get-translation-result-internal', json = {'task_id': task_id, 'nonce': nonce}).json() if 'result' in ret : translated_sentences = ret['result'] break await asyncio.sleep(0.1) - if not translated_sentences and text_regions : - update_state(task_id, nonce, 'error') - return + # if not translated_sentences and text_regions : + # update_state(task_id, nonce, 'error') + # return - print(' -- Rendering translated text') - if mode == 'web' and task_id : - update_state(task_id, nonce, 'render') - # render translated texts - output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, args.force_horizontal) - - print(' -- Saving results') - if dst_image_name : - cv2.imwrite(dst_image_name, cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) - else : - cv2.imwrite(f'result/{task_id}/final.png', cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) + if translated_sentences is not None: + print(' -- Rendering translated text') + if mode == 'web' and task_id : + update_state(task_id, nonce, 'render') + # render translated texts + if args.use_ctd: + from text_rendering import dispatch_ctd_render + output = await dispatch_ctd_render(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, text_regions, args.force_horizontal) + else: + output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, args.force_horizontal) + + print(' -- Saving results') + if dst_image_name : + cv2.imwrite(dst_image_name, cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) + else : + cv2.imwrite(f'result/{task_id}/final.png', cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) if mode == 'web' and task_id : update_state(task_id, nonce, 'finished') @@ -182,7 +209,11 @@ async def main(mode = 'demo') : with open('alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp : dictionary = [s[:-1] for s in fp.readlines()] load_ocr_model(dictionary, args.use_cuda) - load_detection_model(args.use_cuda) + if args.use_ctd: + from textblockdetector import load_model as load_ctd_model + load_ctd_model(args.use_cuda) + else: + load_detection_model(args.use_cuda) load_inpainting_model(args.use_cuda) if mode == 'demo' : diff --git a/translators/__init__.py b/translators/__init__.py index ae981955e..24cd1dac2 100644 --- a/translators/__init__.py +++ b/translators/__init__.py @@ -115,7 +115,12 @@ GOOGLE_CLIENT = google.Translator() BAIDU_CLIENT = baidu.Translator() YOUDAO_CLIENT = youdao.Translator() -DEEPL_CLIENT = deepl.Translator() +try: + DEEPL_CLIENT = deepl.Translator() +except Exception as e: + DEEPL_CLIENT = GOOGLE_CLIENT + print(f'fail to initialize deepl :\n{str(e)} \nswitch to google translator') + async def dispatch(translator: str, src_lang: str, tgt_lang: str, texts: List[str], *args, **kwargs) -> List[str] : if translator not in ['google', 'youdao', 'baidu', 'deepl', 'null'] : From 5e500f3e4121beb4fa9d91063b2db9ac7ebb3135 Mon Sep 17 00:00:00 2001 From: dmMaze Date: Sun, 23 Jan 2022 16:59:04 +0800 Subject: [PATCH 2/3] fix some bugs, rm cached_property decorator and .structure of TextBlock --- .gitignore | 1 - ocr/__init__.py | 21 +++++++---- text_rendering/__init__.py | 17 ++++----- 
textblockdetector/textblock.py | 66 +++++++++++++++++++++------------- 4 files changed, 63 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index a07961125..49043ad01 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,5 @@ result *.ckpt *.pt .vscode -translators __pycache__ ocrs \ No newline at end of file diff --git a/ocr/__init__.py b/ocr/__init__.py index 45be753e6..6bf6d7898 100644 --- a/ocr/__init__.py +++ b/ocr/__init__.py @@ -92,15 +92,22 @@ def run_ocr_32px(img: np.ndarray, cuda: bool, quadrilaterals: List[Tuple[Union[Q cur_region = quadrilaterals[indices[i]][0] if isinstance(cur_region, Quadrilateral): cur_region.text = txt + cur_region.prob = prob + cur_region.fg_r = fr + cur_region.fg_g = fg + cur_region.fg_b = fb + cur_region.bg_r = br + cur_region.bg_g = bg + cur_region.bg_b = bb else: cur_region.text.append(txt) - cur_region.prob = prob - cur_region.fg_r = fr - cur_region.fg_g = fg - cur_region.fg_b = fb - cur_region.bg_r = br - cur_region.bg_g = bg - cur_region.bg_b = bb + cur_region.fg_r += fr + cur_region.fg_g += fg + cur_region.fg_b += fb + cur_region.bg_r += br + cur_region.bg_g += bg + cur_region.bg_b += bb + out_regions.append(cur_region) return out_regions diff --git a/text_rendering/__init__.py b/text_rendering/__init__.py index 7f175cc3e..19cddcddf 100644 --- a/text_rendering/__init__.py +++ b/text_rendering/__init__.py @@ -115,16 +115,15 @@ async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translate async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], force_horizontal: bool) -> np.ndarray : for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) : + print(f'text: {region.get_text()} \n trans: {trans_text}') if not trans_text : continue if force_horizontal : majority_dir = 'h' else: majority_dir = 'v' if region.vertical else 'h' - print(region.text) - print(trans_text) - fg = (region.fg_r, region.fg_g, region.fg_b) - bg = (region.bg_r, region.bg_g, region.bg_b) + + fg, bg = region.get_font_colors() font_size = region.font_size font_size = round(font_size) @@ -193,7 +192,7 @@ async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer x, y, w, h = cv2.boundingRect(tmp_mask) r_prime = w / h - r = region.aspect_ratio + r = region.aspect_ratio() if majority_dir != 'v': r = 1 / r @@ -209,12 +208,10 @@ async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer src_pts = np.array([[x - w_ext, y - h_ext], [x + w + w_ext, y - h_ext], [x + w + w_ext, y + h + h_ext], [x - w_ext, y + h + h_ext]]).astype(np.float32) src_pts[:, 0] = np.clip(np.round(src_pts[:, 0]), 0, enlarged_w * 2) src_pts[:, 1] = np.clip(np.round(src_pts[:, 1]), 0, enlarged_h * 2) - # dst_pts = region.mini_rect[:, [3, 0, 1, 2]] + + dst_pts = region.min_rect() if majority_dir == 'v': - dst_pts = region.mini_rect[:, [3, 0, 1, 2]] - else: - dst_pts = region.mini_rect - # dst_pts = region.mini_rect + dst_pts = dst_pts[:, [3, 0, 1, 2]] M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) tmp_rgba = np.concatenate([tmp_canvas, tmp_mask[:, :, None]], axis = -1).astype(np.float32) rgba_region = np.clip(cv2.warpPerspective(tmp_rgba, M, (img_canvas.shape[1], img_canvas.shape[0]), flags = cv2.INTER_LINEAR, borderMode = cv2.BORDER_CONSTANT, borderValue = 0), 0, 255) diff --git a/textblockdetector/textblock.py b/textblockdetector/textblock.py index b4b04ec94..b0cadeddc 100644 --- a/textblockdetector/textblock.py 
+++ b/textblockdetector/textblock.py @@ -47,13 +47,14 @@ def __init__(self, xyxy: List, self.structure = None self.text = list() - self.prob = None - self.fg_r = None - self.fg_g = None - self.fg_b = None - self.bg_r = None - self.bg_g = None - self.bg_b = None + self.prob = 1 + # note they're accumulative rgb values of textlines + self.fg_r = 0 + self.fg_g = 0 + self.fg_b = 0 + self.bg_r = 0 + self.bg_g = 0 + self.bg_b = 0 def adjust_bbox(self, with_bbox=False): lines = np.array(self.lines) @@ -74,21 +75,19 @@ def sort_lines(self): self.distance = self.distance[idx] lines = np.array(self.lines, dtype=np.int32) self.lines = lines[idx].tolist() - self.structure = self.structure[idx] + # self.structure = self.structure[idx] def lines_array(self, dtype=np.float64): return np.array(self.lines, dtype=dtype) - @functools.cached_property def aspect_ratio(self) -> float: - mini_rect = self.mini_rect - middle_pnts = (mini_rect[:, [1, 2, 3, 0]] + mini_rect) / 2 + min_rect = self.min_rect() + middle_pnts = (min_rect[:, [1, 2, 3, 0]] + min_rect) / 2 norm_v = np.linalg.norm(middle_pnts[:, 2] - middle_pnts[:, 0]) norm_h = np.linalg.norm(middle_pnts[:, 1] - middle_pnts[:, 3]) return norm_v / norm_h - @functools.cached_property - def mini_rect(self): + def min_rect(self): center = [self.xyxy[0]/2, self.xyxy[1]/2] polygons = self.lines_array().reshape(-1, 8) rotated_polygons = rotate_polygons(center, polygons, self.angle) @@ -100,6 +99,15 @@ def mini_rect(self): min_bbox = rotate_polygons(center, min_bbox, -self.angle) return min_bbox.reshape(-1, 4, 2) + def get_font_colors(self): + num_lines = len(self.lines) + if num_lines > 0: + frgb = (np.array([self.fg_r, self.fg_g, self.fg_b]) / num_lines).astype(np.int32) + brgb = (np.array([self.bg_r, self.bg_g, self.bg_b]) / num_lines).astype(np.int32) + return (frgb[0], frgb[1], frgb[2]), (brgb[0], brgb[1], brgb[2]) + else: + return (0, 0, 0), (0, 0, 0) + def __getattribute__(self, name: str): if name == 'pts': return self.lines_array() @@ -122,19 +130,28 @@ def to_dict(self, extra_info=False): return blk_dict def get_transformed_region(self, img, idx, textheight) -> np.ndarray : - [l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure[idx]] - v_vec = l2a - l1a - h_vec = l1b - l2b - ratio = np.linalg.norm(v_vec) / np.linalg.norm(h_vec) - src_pts = self.pts[idx].astype(np.float32) + im_h, im_w = img.shape[:2] direction = 'v' if self.vertical else 'h' + src_pts = np.array(self.lines[idx], dtype=np.float64) + + if self.language == 'eng' or (self.language == 'unknown' and not self.vertical): + e_size = self.font_size / 3 + src_pts[..., 0] += np.array([-e_size, e_size, e_size, -e_size]) + src_pts[..., 1] += np.array([-e_size, -e_size, e_size, e_size]) + src_pts[..., 0] = np.clip(src_pts[..., 0], 0, im_w) + src_pts[..., 1] = np.clip(src_pts[..., 1], 0, im_h) + + middle_pnt = (src_pts[[1, 2, 3, 0]] + src_pts) / 2 + vec_v = middle_pnt[2] - middle_pnt[0] # vertical vectors of textlines + vec_h = middle_pnt[1] - middle_pnt[3] # horizontal vectors of textlines + ratio = np.linalg.norm(vec_v) / np.linalg.norm(vec_h) + if direction == 'h' : h = int(textheight) w = int(round(textheight / ratio)) dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32) M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) region = cv2.warpPerspective(img, M, (w, h)) - return region elif direction == 'v' : w = int(textheight) h = int(round(textheight * ratio)) @@ -142,9 +159,9 @@ def get_transformed_region(self, img, idx, textheight) -> 
np.ndarray : M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0) region = cv2.warpPerspective(img, M, (w, h)) region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE) - # cv2.imshow('region'+str(idx), region) - # cv2.waitKey(0) - return region + # cv2.imshow('region'+str(idx), region) + # cv2.waitKey(0) + return region def get_text(self): return ' '.join(self.text) @@ -216,7 +233,7 @@ def examine_textblk(blk: TextBlock, im_w: int, im_h: int, eval_orientation: bool blk.vertical = vertical blk.vec = primary_vec blk.norm = primary_norm - blk.structure = middle_pnts + # blk.structure = middle_pnts if sort: blk.sort_lines() @@ -245,6 +262,7 @@ def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.3, distanc blk.angle = int(round(np.rad2deg(math.atan2(vec_sum[1], vec_sum[0])))) blk.norm = np.linalg.norm(vec_sum) blk.distance = np.append(blk.distance, blk2.distance[-1]) + # blk.structure = np.concatenate((blk.structure, blk2.structure)) blk.font_size = fntsz_avg blk2.merged = True return True @@ -368,7 +386,7 @@ def visualize_textblocks(canvas, blk_list: List[TextBlock]): for jj, line in enumerate(lines): cv2.putText(canvas, str(jj), line[0], cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,127,0), 1) cv2.polylines(canvas, [line], True, (0,127,255), 2) - cv2.polylines(canvas, [blk.mini_rect], True, (127,127,0), 2) + cv2.polylines(canvas, [blk.min_rect()], True, (127,127,0), 2) center = [int((bx1 + bx2)/2), int((by1 + by2)/2)] cv2.putText(canvas, str(blk.angle), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (127,127,255), 2) cv2.putText(canvas, str(ii), (bx1, by1 + lw + 2), 0, lw / 3, (255,127,127), max(lw-1, 1), cv2.LINE_AA) From b3c6d2ec825a069db1b129b3dbc896b8bd6ee314 Mon Sep 17 00:00:00 2001 From: dmMaze Date: Sun, 23 Jan 2022 16:59:56 +0800 Subject: [PATCH 3/3] add textblock breaker for google translator --- translators/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/translators/__init__.py b/translators/__init__.py index 24cd1dac2..fcfad483c 100644 --- a/translators/__init__.py +++ b/translators/__init__.py @@ -137,11 +137,13 @@ async def dispatch(translator: str, src_lang: str, tgt_lang: str, texts: List[st src_lang = LANGUAGE_CODE_MAP[translator][src_lang] if src_lang != 'auto' else 'auto' if tgt_lang == 'NONE' or src_lang == 'NONE' : raise Exception + TEXTBLK_BREAK = '\n###\n' if translator == 'google' : - concat_texts = '\n'.join(texts) + concat_texts = TEXTBLK_BREAK.join(texts) result = await GOOGLE_CLIENT.translate(concat_texts, tgt_lang, src_lang, *args, **kwargs) - if not isinstance(result, list) : - result = result.text.split('\n') + if not isinstance(result, list): + result = result.text.split(TEXTBLK_BREAK.replace('\n', '')) + result = [text.lstrip().rstrip() for text in result] elif translator == 'baidu' : concat_texts = '\n'.join(texts) result = await BAIDU_CLIENT.translate(src_lang, tgt_lang, concat_texts)
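The google-translator path in PATCH 3/3 batches every text block into one request: the texts are joined with TEXTBLK_BREAK ('\n###\n'), translated in a single call, then split back on the bare '###' token (the delimiter with its newlines stripped) and whitespace-trimmed, which tolerates the service reflowing or dropping line breaks. Below is a minimal standalone sketch of that join/translate/split round trip; fake_translate is only an illustrative stand-in for the real GOOGLE_CLIENT.translate call.

TEXTBLK_BREAK = '\n###\n'

def fake_translate(text: str) -> str:
    # stand-in for the real (async) translation request
    return text.upper()

def translate_blocks(texts):
    # join all block texts into one payload, translate once, then split back
    concat = TEXTBLK_BREAK.join(texts)
    translated = fake_translate(concat)
    # split on '###' alone in case the service altered the surrounding newlines
    parts = translated.split(TEXTBLK_BREAK.replace('\n', ''))
    return [p.strip() for p in parts]

print(translate_blocks(['first block', 'second block']))
# -> ['FIRST BLOCK', 'SECOND BLOCK']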
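PATCH 2/3 also changes how font colours are carried: for TextBlock regions the OCR loop now accumulates each line's foreground/background estimate into running sums (fg_r += fr, ...), and TextBlock.get_font_colors() divides those sums by the number of lines to get the block-average colours used for rendering. A small self-contained sketch of that accumulate-then-average idea, assuming plain RGB triples; SimpleBlock is an illustrative stand-in, not the real TextBlock.

import numpy as np

class SimpleBlock:
    def __init__(self):
        self.num_lines = 0       # number of OCR'd text lines merged into this block
        self.fg = np.zeros(3)    # accumulated foreground RGB over those lines
        self.bg = np.zeros(3)    # accumulated background RGB over those lines

    def add_line_colors(self, fg_rgb, bg_rgb):
        self.num_lines += 1
        self.fg += np.array(fg_rgb, dtype=np.float64)
        self.bg += np.array(bg_rgb, dtype=np.float64)

    def get_font_colors(self):
        # average the accumulated sums; fall back to black if no lines were seen
        if self.num_lines == 0:
            return (0, 0, 0), (0, 0, 0)
        fg = (self.fg / self.num_lines).astype(int).tolist()
        bg = (self.bg / self.num_lines).astype(int).tolist()
        return tuple(fg), tuple(bg)

blk = SimpleBlock()
blk.add_line_colors((200, 10, 10), (250, 250, 250))
blk.add_line_colors((220, 30, 30), (240, 240, 240))
print(blk.get_font_colors())
# -> ((210, 20, 20), (245, 245, 245))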