From ab826fc495844512d50c2e7c28cade8ae2a073d6 Mon Sep 17 00:00:00 2001 From: zyddnys Date: Mon, 24 Jan 2022 18:01:34 -0500 Subject: [PATCH] update for new model --- text_rendering/__init__.py | 12 ++-- translate_demo.py | 78 ++++++++++++------------ ui.html | 30 +++++++++- web_main.py | 120 +++++++++++++------------------------ 4 files changed, 117 insertions(+), 123 deletions(-) diff --git a/text_rendering/__init__.py b/text_rendering/__init__.py index 19cddcddf..2bbdb2df1 100644 --- a/text_rendering/__init__.py +++ b/text_rendering/__init__.py @@ -8,12 +8,12 @@ from . import text_render from textblockdetector.textblock import TextBlock -async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], force_horizontal: bool) -> np.ndarray : +async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], text_direction_overwrite: str) -> np.ndarray : for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) : if not trans_text : continue - if force_horizontal : - region.majority_dir = 'h' + if text_direction_overwrite and text_direction_overwrite in ['h', 'v'] : + region.majority_dir = text_direction_overwrite print(region.text) print(trans_text) #print(region.majority_dir, region.pts) @@ -113,13 +113,13 @@ async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translate return img_canvas -async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], force_horizontal: bool) -> np.ndarray : +async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], text_direction_overwrite: str) -> np.ndarray : for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) : print(f'text: {region.get_text()} \n trans: {trans_text}') if not trans_text : continue - if force_horizontal : - majority_dir = 'h' + if text_direction_overwrite and text_direction_overwrite in ['h', 'v'] : + majority_dir = text_direction_overwrite else: majority_dir = 'v' if region.vertical else 'h' diff --git a/translate_demo.py b/translate_demo.py index 4ebc5e6e3..3a9feb2f0 100755 --- a/translate_demo.py +++ b/translate_demo.py @@ -1,12 +1,21 @@ import asyncio -import torch -import einops import argparse import cv2 import numpy as np import requests +import os from oscrypto import util as crypto_utils +import asyncio + +from detection import dispatch as dispatch_detection, load_model as load_detection_model +from ocr import dispatch as dispatch_ocr, load_model as load_ocr_model +from inpainting import dispatch as dispatch_inpainting, load_model as load_inpainting_model +from text_mask import dispatch as dispatch_mask_refinement +from textline_merge import dispatch as dispatch_textline_merge +from text_rendering import dispatch as dispatch_rendering +from textblockdetector import dispatch as dispatch_ctd_detection +from textblockdetector.textblock import visualize_textblocks parser = argparse.ArgumentParser(description='Generate text bboxes given a image file') parser.add_argument('--mode', default='demo', type=str, help='Run demo in either single image demo mode (demo), web service mode (web) or batch translation mode (batch)') @@ -47,32 +56,25 @@ def update_state(task_id, nonce, state) : def get_task(nonce) : try : rjson = requests.get(f'http://127.0.0.1:5003/task-internal?nonce={nonce}').json() - if 'task_id' in rjson : - return rjson['task_id'] + if 'task_id' in rjson and 'data' in rjson : + return rjson['task_id'], rjson['data'] else : - return None + return None, None except : - return None - -from detection import dispatch as dispatch_detection, load_model as load_detection_model -from ocr import dispatch as dispatch_ocr, load_model as load_ocr_model -from inpainting import dispatch as dispatch_inpainting, load_model as load_inpainting_model -from text_mask import dispatch as dispatch_mask_refinement -from textline_merge import dispatch as dispatch_textline_merge -from text_rendering import dispatch as dispatch_rendering -from textblockdetector import dispatch as dispatch_ctd_detection -from textblockdetector.textblock import visualize_textblocks + return None, None async def infer( img, mode, nonce, + options = None, task_id = '', dst_image_name = '' ) : + options = options or {} img_detect_size = args.size - if task_id and len(task_id) != 32 : - size_ind = task_id[-1] + if 'size' in options : + size_ind = options['size'] if size_ind == 'S' : img_detect_size = 1024 elif size_ind == 'M' : @@ -81,20 +83,29 @@ async def infer( img_detect_size = 2048 elif size_ind == 'X' : img_detect_size = 2560 - print(f' -- Detection size {size_ind}, resolution {img_detect_size}') + print(f' -- Detection resolution {img_detect_size}') + detector = 'ctd' if args.use_ctd else 'default' + if 'detector' in options : + detector = options['detector'] + print(f' -- Detector using {detector}') + render_text_direction_overwrite = 'h' if args.force_horizontal else '' + if 'direction' in options : + if options['direction'] == 'horizontal' : + render_text_direction_overwrite = 'h' + print(f' -- Render text direction is {render_text_direction_overwrite or "auto"}') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if mode == 'web' and task_id : update_state(task_id, nonce, 'detection') - if args.use_ctd: + if detector == 'ctd' : mask, final_mask, textlines = await dispatch_ctd_detection(img, args.use_cuda) text_regions = textlines else: textlines, mask = await dispatch_detection(img, img_detect_size, args.use_cuda, args, verbose = args.verbose) if args.verbose : - if args.use_ctd: + if detector == 'ctd' : bboxes = visualize_textblocks(cv2.cvtColor(img,cv2.COLOR_BGR2RGB), textlines) cv2.imwrite(f'result/{task_id}/bboxes.png', bboxes) cv2.imwrite(f'result/{task_id}/mask_raw.png', mask) @@ -110,7 +121,7 @@ async def infer( update_state(task_id, nonce, 'ocr') textlines = await dispatch_ocr(img, textlines, args.use_cuda, args) - if not args.use_ctd: + if detector == 'default' : text_regions, textlines = await dispatch_textline_merge(textlines, img.shape[1], img.shape[0], verbose = args.verbose) if args.verbose : img_bbox = np.copy(img) @@ -155,7 +166,7 @@ async def infer( print(' -- Translating') # try: from translators import dispatch as run_translation - if args.use_ctd: + if detector == 'ctd' : translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.get_text() for r in text_regions]) else: translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.text for r in text_regions]) @@ -177,11 +188,11 @@ async def infer( if mode == 'web' and task_id : update_state(task_id, nonce, 'render') # render translated texts - if args.use_ctd: + if detector == 'ctd' : from text_rendering import dispatch_ctd_render - output = await dispatch_ctd_render(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, text_regions, args.force_horizontal) + output = await dispatch_ctd_render(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, text_regions, render_text_direction_overwrite) else: - output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, args.force_horizontal) + output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, render_text_direction_overwrite) print(' -- Saving results') if dst_image_name : @@ -192,10 +203,6 @@ async def infer( if mode == 'web' and task_id : update_state(task_id, nonce, 'finished') -from PIL import Image -import time -import asyncio - def replace_prefix(s: str, old: str, new: str) : if s.startswith(old) : s = new + s[len(old):] @@ -203,17 +210,14 @@ def replace_prefix(s: str, old: str, new: str) : async def main(mode = 'demo') : print(' -- Loading models') - import os os.makedirs('result', exist_ok = True) text_render.prepare_renderer() with open('alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp : dictionary = [s[:-1] for s in fp.readlines()] load_ocr_model(dictionary, args.use_cuda) - if args.use_ctd: - from textblockdetector import load_model as load_ctd_model - load_ctd_model(args.use_cuda) - else: - load_detection_model(args.use_cuda) + from textblockdetector import load_model as load_ctd_model + load_ctd_model(args.use_cuda) + load_detection_model(args.use_cuda) load_inpainting_model(args.use_cuda) if mode == 'demo' : @@ -232,12 +236,12 @@ async def main(mode = 'demo') : import sys subprocess.Popen([sys.executable, 'web_main.py', nonce, '5003']) while True : - task_id = get_task(nonce) + task_id, options = get_task(nonce) if task_id : print(f' -- Processing task {task_id}') img = cv2.imread(f'result/{task_id}/input.png') try : - infer_task = asyncio.create_task(infer(img, mode, nonce, task_id)) + infer_task = asyncio.create_task(infer(img, mode, nonce, options, task_id)) asyncio.gather(infer_task) except : import traceback diff --git a/ui.html b/ui.html index 21cddf3f2..02a5b53c5 100644 --- a/ui.html +++ b/ui.html @@ -112,6 +112,8 @@ event.preventDefault(); var detect_res = document.querySelector('input[name="detect-size"]:checked').value; var translator = document.querySelector('input[name="translator-sel"]:checked').value; + var dir = document.querySelector('input[name="dir-sel"]:checked').value; + var detector = document.querySelector('input[name="detector-sel"]:checked').value; var tgt_lang = document.getElementById("target-language").value; console.log(document.getElementById("target-language")); var files = document.getElementById("image-file").files; @@ -135,6 +137,8 @@ formData.append('size', detect_res); formData.append('translator', translator); formData.append('tgt_lang', tgt_lang); + formData.append('dir', dir); + formData.append('detector', detector); XHR.open('POST', BASE_URI + "submit", true); XHR.onload = async function () { if (XHR.status == 200) { @@ -185,7 +189,7 @@

Image/Manga translator

- Detection resolution: + Detection resolution
- Translator: + Text detector + + +
+
+ Translator
+
+ Render text direction + + +
Target language: