-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from jokester/experiment/moeflow-companion
PoC: assisted translation as addon
- Loading branch information
Showing
8 changed files
with
378 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Translation preprocess API backed by manga-image-translator worker | ||
from app.core.views import MoeAPIView | ||
from flask import request | ||
|
||
from app.exceptions.base import UploadFileNotFoundError | ||
from app.tasks.mit import ( | ||
mit_ocr, | ||
mit_detect_text, | ||
mit_translate, | ||
mit_detect_text_default_params, | ||
mit_ocr_default_params, | ||
) | ||
from app.tasks import queue_task, wait_result_sync | ||
from app import app_config | ||
from werkzeug.datastructures import FileStorage | ||
from tempfile import NamedTemporaryFile | ||
import os | ||
from app.utils.logging import logger | ||
|
||
# Directory used by MitImageApi.post to store uploads and by
# MitImageTaskApi.post to resolve them. Falls back to a placeholder path
# when the "MIT_STORAGE_ROOT" setting is absent.
MIT_STORAGE_ROOT = app_config.get(
    "MIT_STORAGE_ROOT",
    "/MIT_STORAGE_ROOT_UNDEFINED",
)
|
||
|
||
def _wait_task_result(task_id: str):
    """Poll a queued task briefly and describe its state as a dict.

    Calls ``wait_result_sync`` with ``timeout=1`` (unit not visible here,
    presumably seconds) and maps the outcome to a response:

    * ``"success"`` with the task's ``result`` when the call returns,
    * ``"pending"`` when it raises ``TimeoutError``,
    * ``"fail"`` with the stringified error for any other exception.

    The response always carries the ``task_id`` that was asked about.
    """
    response = {"task_id": task_id}
    try:
        outcome = wait_result_sync(task_id, timeout=1)
    except TimeoutError:
        response["status"] = "pending"
    except Exception as err:
        # Any other failure is reported to the client instead of raised.
        response.update(status="fail", message=str(err))
    else:
        response.update(status="success", result=outcome)
    return response
|
||
|
||
class MitImageApi(MoeAPIView):
    """Upload endpoint that stores an image file for the other MIT APIs."""

    def post(self):
        """Save the uploaded ``file`` part under ``MIT_STORAGE_ROOT``.

        Returns:
            ``{"filename": <path of the stored temp file>}``; that value is
            what ``MitImageTaskApi.post`` expects as ``filename``.

        Raises:
            UploadFileNotFoundError: when no ``file`` part was uploaded or
                its name does not end in a supported image extension.
        """
        logger.info("files: %s", request.files)
        blob: None | FileStorage = request.files.get("file")
        logger.info("blob: %s", blob)
        image_suffixes = (".jpg", ".jpeg", ".png", ".gif")
        # Compare case-insensitively so "PHOTO.JPG" is accepted as well.
        if not (blob and blob.filename.lower().endswith(image_suffixes)):
            raise UploadFileNotFoundError("Please select an image file")
        # delete=False keeps the file on disk after the handle is closed;
        # the context manager closes the handle even if the write fails.
        with NamedTemporaryFile(dir=MIT_STORAGE_ROOT, delete=False) as tmpfile:
            tmpfile.write(blob.read())
        return {"filename": tmpfile.name}
|
||
|
||
class MitImageTaskApi(MoeAPIView):
    """Queue detect-text / OCR tasks for an uploaded image and poll them."""

    def post(self):
        """Queue the task named in the JSON body for a stored image.

        The JSON body must contain ``filename`` (as returned by
        ``MitImageApi.post``) and ``task_name`` (``"mit_detect_text"`` or
        ``"mit_ocr"``). Every remaining key overrides the matching default
        task parameter.

        Returns:
            ``{"task_id": <id returned by queue_task>}``.

        Raises:
            ValueError: when the filename is missing, points outside
                ``MIT_STORAGE_ROOT``, does not name an existing file, or
                when the task name is unknown.
        """
        task_params: dict[str, str] = self.get_json()
        logger.info("task_params: %s", task_params)
        tmpfile_name = task_params.pop("filename", None)
        if not tmpfile_name:
            raise ValueError("Filename required")

        # Normalise both paths before comparing them. A character-wise
        # prefix test on the raw join accepts "<root>/../secret" and
        # sibling directories such as "<root>_other"; commonpath compares
        # whole path components of the collapsed paths instead.
        storage_root = os.path.abspath(MIT_STORAGE_ROOT)
        tmpfile_path = os.path.abspath(os.path.join(storage_root, tmpfile_name))
        if os.path.commonpath([tmpfile_path, storage_root]) != storage_root:
            raise ValueError("Invalid filename")
        if not os.path.isfile(tmpfile_path):
            raise ValueError("File not found")

        task_name = task_params.pop("task_name", None)
        known_tasks = {
            "mit_detect_text": (mit_detect_text, mit_detect_text_default_params),
            "mit_ocr": (mit_ocr, mit_ocr_default_params),
        }
        # The isinstance guard keeps unhashable JSON values (lists, dicts)
        # on the ValueError path instead of raising TypeError on lookup.
        entry = known_tasks.get(task_name) if isinstance(task_name, str) else None
        if entry is None:
            raise ValueError("Invalid task name")

        task, default_params = entry
        # Caller-supplied values take precedence over the defaults.
        merged_params = {**default_params, **task_params}
        task_id = queue_task(task, tmpfile_path, **merged_params)
        return {"task_id": task_id}

    def get(self, task_id: str):
        """Return the current state of a previously queued task."""
        return _wait_task_result(task_id)
|
||
|
||
class MitTranslateTaskApi(MoeAPIView):
    """Queue translation tasks and poll their results."""

    def post(self):
        """Queue ``mit_translate`` with the JSON body as keyword arguments.

        Returns:
            ``{"task_id": <id returned by queue_task>}``.
        """
        return {"task_id": queue_task(mit_translate, **self.get_json())}

    def get(self, task_id: str):
        """Return the current state of a previously queued task."""
        return _wait_task_result(task_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
""" | ||
Text segmentation + OCR using manga-image-translator | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any | ||
from app import celery as celery_app | ||
from celery import Task | ||
from celery.result import AsyncResult | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# Extra keyword arguments passed along with the detect-text and OCR task
# calls made from this module.
gpu_options = dict(device="cuda")
|
||
|
||
@celery_app.task(name="tasks.mit.detect_text")
def _mit_detect_text(path_or_url: str, **kwargs):
    """Declare the ``tasks.mit.detect_text`` task; this body does nothing.

    Real implementation is in manga_translator/moeflow_worker.py
    """
|
||
|
||
@celery_app.task(name="tasks.mit.ocr")
def _mit_ocr(path_or_url: str, **kwargs):
    """Declare the ``tasks.mit.ocr`` task; this body does nothing.

    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
|
||
|
||
@celery_app.task(name="tasks.mit.translate")
def _mit_translate(**kwargs):
    """Declare the ``tasks.mit.translate`` task; this body does nothing.

    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
|
||
|
||
@celery_app.task(name="tasks.mit.inpaint")
def _mit_inpaint(path_or_url: str, **kwargs):
    """Declare the ``tasks.mit.inpaint`` task; this body does nothing.

    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
|
||
|
||
def to_dict(**kwargs) -> dict[str, Any]:
    """Collect the given keyword arguments into a dict, in call order."""
    collected: dict[str, Any] = kwargs
    return collected
|
||
|
||
# Default keyword arguments for the detect-text task.
mit_detect_text_default_params: dict[str, Any] = {
    "detector_key": "default",
    # mostly defaults from manga-image-translator/args.py
    "detect_size": 2560,
    "text_threshold": 0.5,
    "box_threshold": 0.7,
    "unclip_ratio": 2.3,
    "invert": False,
    "gamma_correct": False,
    "rotate": False,
    "verbose": True,
}
|
||
|
||
def _run_mit_detect_text(image_path: str) -> dict:
    """Run the detect-text task for ``image_path`` and block for its result.

    The task is submitted with ``mit_detect_text_default_params`` plus
    ``gpu_options``; the returned value is whatever the task produced.
    """
    pending: AsyncResult = mit_detect_text.delay(
        image_path, **mit_detect_text_default_params, **gpu_options
    )
    # XXX waiting on a task result synchronously is unrecommended, but it
    # should not cause a deadlock here.
    detection: dict = pending.get(disable_sync_subtasks=False)  # type: ignore
    logger.info("detect_text finished: %s", detection)
    return detection
|
||
|
||
# Default keyword arguments for the OCR task.
mit_ocr_default_params: dict[str, Any] = {
    "ocr_key": "48px",  # recommended by rowland
    # Alternatives that were considered:
    #   ocr_key="48px_ctc"
    #   ocr_key="mocr"  (XXX: mocr may have different output format)
    #   use_mocr_merge=True
    "verbose": True,
}
|
||
|
||
def _run_mit_ocr(image_path: str, regions: list[dict]) -> list[dict]:
    """Run the OCR task on ``regions`` of ``image_path`` and block for it.

    The task is submitted with ``mit_ocr_default_params`` plus
    ``gpu_options``. The result is logged as a whole and then entry by
    entry before being returned unchanged.
    """
    pending: AsyncResult = mit_ocr.delay(
        image_path, regions=regions, **mit_ocr_default_params, **gpu_options
    )
    # XXX waiting on a task result synchronously is unrecommended, but it
    # should not cause a deadlock here.
    recognized: list[dict] = pending.get(disable_sync_subtasks=False)
    logger.info("ocr finished: %s", recognized)
    for entry in recognized:
        logger.info("ocr extracted text: %s", entry)
    return recognized
|
||
|
||
def _run_mit_translate(
    text: str,
    translator: str = "gpt3.5",
    target_lang: str = "CHT",
) -> str:
    """Translate ``text`` via the translate task and block for the result.

    Args:
        text: Source text, sent to the task as ``query``.
        translator: Translator key forwarded to the task.
        target_lang: Target language code forwarded to the task.

    Returns:
        The first element of the task result.

    Raises:
        IndexError: when the task result is empty.
    """
    pending: AsyncResult = mit_translate.delay(
        query=text,
        translator=translator,
        target_lang=target_lang,
    )
    # XXX waiting on a task result synchronously is unrecommended, but it
    # should not cause a deadlock here.
    # NOTE(review): the result is indexed below, so it is annotated as a
    # list of translations rather than a single string -- confirm against
    # the worker implementation.
    result: list[str] = pending.get(disable_sync_subtasks=False)
    logger.info("translated %s %s", text, result)
    return result[0]
|
||
|
||
@celery_app.task(name="tasks.preprocess_mit")
def _preprocess_mit(image_path: str, target_lang: str):
    """Detect, OCR and translate the text of one image.

    Runs the three steps in sequence, each blocking on its own task, and
    returns a plain dict with ``image_path``, ``target_lang`` and
    ``text_quads`` (one ``pts`` / ``raw_text`` / ``translated`` entry per
    OCR result) -- the same layout ``MitPreprocessedImage.from_dict`` reads.
    """
    detection = _run_mit_detect_text(image_path)
    regions = _run_mit_ocr(image_path, detection["textlines"])
    # Translate every region first; the quads are assembled afterwards.
    translations = [
        _run_mit_translate(region["text"], target_lang=target_lang)
        for region in regions
    ]
    text_quads = [
        {
            "pts": region["pts"],
            "raw_text": region["text"],
            "translated": translation,
        }
        for region, translation in zip(regions, translations)
    ]
    return {
        "image_path": image_path,
        "target_lang": target_lang,
        "text_quads": text_quads,
    }
|
||
|
||
@dataclass(frozen=True)
class MitTextQuad:
    """One text region: its points, the raw text and the translation."""

    pts: list[tuple[int, int]]
    raw_text: str
    translated: str

    def to_dict(self) -> dict[str, Any]:
        """Return the three fields as a plain dict (values are not copied)."""
        field_names = ("pts", "raw_text", "translated")
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitTextQuad":
        """Build an instance from a dict shaped like ``to_dict()`` output.

        Raises:
            KeyError: when one of the three keys is missing.
        """
        field_names = ("pts", "raw_text", "translated")
        return cls(**{name: d[name] for name in field_names})
|
||
|
||
@dataclass(frozen=True)
class MitPreprocessedImage:
    """Preprocessing result for one image: path, language and text quads."""

    image_path: str
    target_lang: str
    text_quads: list[MitTextQuad]

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form with every quad converted via ``to_dict``."""
        serialized: dict[str, Any] = {
            "image_path": self.image_path,
            "target_lang": self.target_lang,
        }
        serialized["text_quads"] = [quad.to_dict() for quad in self.text_quads]
        return serialized

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitPreprocessedImage":
        """Build an instance from a dict shaped like ``to_dict()`` output.

        Raises:
            KeyError: when a required key is missing.
        """
        image_path = d["image_path"]
        target_lang = d["target_lang"]
        quads = [MitTextQuad.from_dict(raw) for raw in d["text_quads"]]
        return cls(image_path=image_path, target_lang=target_lang, text_quads=quads)
|
||
|
||
# Public aliases for the task functions declared in this module, annotated
# as celery ``Task`` so callers can use ``.delay(...)`` and friends with a
# better type. The assignments are ignored by the type checker on purpose.
mit_detect_text: Task = _mit_detect_text  # type: ignore
mit_ocr: Task = _mit_ocr  # type: ignore
mit_translate: Task = _mit_translate  # type: ignore
mit_inpaint: Task = _mit_inpaint  # type: ignore
preprocess_mit: Task = _preprocess_mit  # type: ignore
Oops, something went wrong.