diff --git a/app/__init__.py b/app/__init__.py index 49abe67..b0acfd9 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -139,6 +139,7 @@ def create_celery() -> Celery: "app.tasks.ocr", "app.tasks.import_from_labelplus", "app.tasks.thumbnail", + "app.tasks.mit", # only included for completeness's sake. its impl is in other repo. ], related_name=None, ) diff --git a/app/apis/manga_image_translator.py b/app/apis/manga_image_translator.py new file mode 100644 index 0000000..f1fbc18 --- /dev/null +++ b/app/apis/manga_image_translator.py @@ -0,0 +1,87 @@ +# Translation preprocess API backed by manga-image-translator worker +from app.core.views import MoeAPIView +from flask import request + +from app.exceptions.base import UploadFileNotFoundError +from app.tasks.mit import ( + mit_ocr, + mit_detect_text, + mit_translate, + mit_detect_text_default_params, + mit_ocr_default_params, +) +from app.tasks import queue_task, wait_result_sync +from app import app_config +from werkzeug.datastructures import FileStorage +from tempfile import NamedTemporaryFile +import os +from app.utils.logging import logger + +MIT_STORAGE_ROOT = app_config.get("MIT_STORAGE_ROOT", "/MIT_STORAGE_ROOT_UNDEFINED") + + +def _wait_task_result(task_id: str): + try: + result = wait_result_sync(task_id, timeout=1) + return {"task_id": task_id, "status": "success", "result": result} + except TimeoutError: + return { + "task_id": task_id, + "status": "pending", + } + except Exception as e: + return {"task_id": task_id, "status": "fail", "message": str(e)} + + +class MitImageApi(MoeAPIView): + # upload image file for other APIs + def post(self): + logger.info("files: %s", request.files) + blob: None | FileStorage = request.files.get("file") + logger.info("blob: %s", blob) + if not (blob and blob.filename.endswith((".jpg", ".jpeg", ".png", ".gif"))): + raise UploadFileNotFoundError("Please select an image file") + tmpfile = NamedTemporaryFile(dir=MIT_STORAGE_ROOT, delete=False) + tmpfile.write(blob.read()) + tmpfile.close() + return {"filename": tmpfile.name} + + +class MitImageTaskApi(MoeAPIView): + def post(self): + task_params: dict[str, str] = self.get_json() + logger.info("task_params: %s", task_params) + tmpfile_name = task_params.pop("filename", None) + if not tmpfile_name: + raise ValueError("Filename required") + tmpfile_path = os.path.join(MIT_STORAGE_ROOT, tmpfile_name) + if os.path.commonprefix([tmpfile_path, MIT_STORAGE_ROOT]) != MIT_STORAGE_ROOT: + raise ValueError("Invalid filename") + if not os.path.isfile(tmpfile_path): + raise ValueError("File not found") + task_name = task_params.pop("task_name", None) + if task_name == "mit_detect_text": + merged_params = mit_detect_text_default_params.copy() + merged_params.update(task_params) + task_id = queue_task(mit_detect_text, tmpfile_path, **merged_params) + return {"task_id": task_id} + elif task_name == "mit_ocr": + merged_params = mit_ocr_default_params.copy() + merged_params.update(task_params) + task_id = queue_task(mit_ocr, tmpfile_path, **merged_params) + return {"task_id": task_id} + else: + raise ValueError("Invalid task name") + + def get(self, task_id: str): + return _wait_task_result(task_id) + + +class MitTranslateTaskApi(MoeAPIView): + def post(self): + task_params = self.get_json() + task_id = queue_task(mit_translate, **task_params) + return {"task_id": task_id} + + def get(self, task_id: str): + return _wait_task_result(task_id) diff --git a/app/apis/urls.py b/app/apis/urls.py index 0202b8a..249200b 100644 --- a/app/apis/urls.py +++ b/app/apis/urls.py @@ -68,6 +68,12 @@ ) from app.apis.language import LanguageListAPI from app.apis.target import TargetAPI +from app.apis.manga_image_translator import ( + MitImageApi, + MitImageTaskApi, + MitTranslateTaskApi, +) +from app import app_config v1_prefix = "/v1" # api主页 @@ -246,7 +252,7 @@ methods=["GET", "PUT", "DELETE", "OPTIONS"], view_func=ProjectSetAPI.as_view("team_project_set"), ) -# 术语库模块 +# 术语库模块 TODO: not used (yet?) in moeflow v1.1.0 term_bank = Blueprint("term_bank", __name__, url_prefix=v1_prefix + "/term-banks") term_bank.add_url_rule( "/", @@ -457,3 +463,31 @@ methods=["GET", "OPTIONS"], view_func=AdminVCodeListAPI.as_view("admin_v_code_list"), ) + +if app_config["MIT_STORAGE_ROOT"]: + mit = Blueprint("manga-image-translator", __name__, url_prefix=v1_prefix + "/mit") + mit.add_url_rule( + "/images", + methods=["POST", "OPTIONS"], + view_func=MitImageApi.as_view("mit_image_upload"), + ) + mit.add_url_rule( + "/image-tasks", + methods=["POST", "OPTIONS"], + view_func=MitImageTaskApi.as_view("mit_image_tasks_create"), + ) + mit.add_url_rule( + "/image-tasks/", + methods=["GET", "OPTIONS"], + view_func=MitImageTaskApi.as_view("mit_image_tasks_query"), + ) + mit.add_url_rule( + "/translate-tasks", + methods=["POST", "OPTIONS"], + view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_create"), + ) + mit.add_url_rule( + "/translate-tasks/", + methods=["GET", "OPTIONS"], + view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_query"), + ) diff --git a/app/config.py b/app/config.py index 32c0627..54056f3 100644 --- a/app/config.py +++ b/app/config.py @@ -112,8 +112,11 @@ # ----------- # Celery # ----------- -CELERY_BROKER_URL = f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}" -CELERY_BACKEND_URL = DB_URI +CELERY_BROKER_URL = env.get( + "CELERY_BROKER_URL", + f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}", +) +CELERY_BACKEND_URL = env.get("CELERY_BACKEND_URL", DB_URI) CELERY_MONGODB_BACKEND_SETTINGS = { "database": env["MONGODB_DB_NAME"], "taskmeta_collection": "celery_taskmeta", diff --git a/app/models/target.py b/app/models/target.py index a625fea..5cb6355 100644 --- a/app/models/target.py +++ b/app/models/target.py @@ -59,7 +59,7 @@ def create( file.create_target_cache(target) return target - def outputs(self): + def outputs(self) -> list[Output]: """所有导出""" outputs = Output.objects(project=self.project, target=self).order_by( "-create_time" diff --git a/app/tasks/__init__.py b/app/tasks/__init__.py index 789ba54..22577de 100644 --- a/app/tasks/__init__.py +++ b/app/tasks/__init__.py @@ -7,8 +7,43 @@ flower --port=5555 --broker=redis://localhost:6379/1 """ +import asyncio +import datetime +from typing import Any +from celery import Task +from celery.result import AsyncResult +from celery.exceptions import TimeoutError as CeleryTimeoutError +from app import celery as celery_app +from asgiref.sync import async_to_sync + class SyncResult: """和celery的delay异步返回类似的结果,用于同步、异步切换""" task_id = "sync" + + +def queue_task(task: Task, *args, **kwargs) -> str: + result = task.delay(*args, **kwargs) + result.forget() + return result.id + + +def wait_result_sync(task_id: str, timeout: int = 10) -> Any: + result = AsyncResult(id=task_id, app=celery_app) + try: + return result.get(timeout=timeout) + except CeleryTimeoutError: + raise TimeoutError + + +@async_to_sync +async def wait_result(task_id: str, timeout: int = 10) -> Any: + start = datetime.datetime.now().timestamp() + result = AsyncResult(id=task_id, app=celery_app) + while not result.ready(): + if (datetime.datetime.now().timestamp() - start) > timeout: + result.forget() + raise TimeoutError + await asyncio.sleep(0.5e3) + return result.get() # type: ignore diff --git a/app/tasks/mit.py b/app/tasks/mit.py new file mode 100644 index 0000000..1d3801b --- /dev/null +++ b/app/tasks/mit.py @@ -0,0 +1,180 @@ +""" +Text segmentation + OCR using manga-image-translator +""" + +from dataclasses import dataclass +from typing import Any +from app import celery as celery_app +from celery import Task +from celery.result import AsyncResult +import logging + +logger = logging.getLogger(__name__) + + +gpu_options = {"device": "cuda"} + + +@celery_app.task(name="tasks.mit.detect_text") +def _mit_detect_text(path_or_url: str, **kwargs): + pass # Real implementation is in manga_translator/moeflow_worker.py + + +@celery_app.task(name="tasks.mit.ocr") +def _mit_ocr(path_or_url: str, **kwargs): + pass + + +@celery_app.task(name="tasks.mit.translate") +def _mit_translate(**kwargs): + pass + + +@celery_app.task(name="tasks.mit.inpaint") +def _mit_inpaint(path_or_url: str, **kwargs): + pass + + +def to_dict(**kwargs) -> dict[str, Any]: + return kwargs + + +mit_detect_text_default_params = to_dict( + detector_key="default", + # mostly defaults from manga-image-translator/args.py + detect_size=2560, + text_threshold=0.5, + box_threshold=0.7, + unclip_ratio=2.3, + invert=False, + gamma_correct=False, + rotate=False, + verbose=True, +) + + +def _run_mit_detect_text(image_path: str) -> dict: + detect_text: AsyncResult = mit_detect_text.delay( + image_path, + **mit_detect_text_default_params, + **gpu_options, + ) + # XXX unrecommended but should not cause dead lock + result: dict = detect_text.get(disable_sync_subtasks=False) # type: ignore + logger.info("detect_text finished: %s", result) + return result + + +mit_ocr_default_params = to_dict( + ocr_key="48px", # recommended by rowland + # ocr_key="48px_ctc", + # ocr_key="mocr", # XXX: mocr may have different output format + # use_mocr_merge=True, + verbose=True, +) + + +def _run_mit_ocr(image_path: str, regions: list[dict]) -> list[dict]: + ocr: AsyncResult = mit_ocr.delay( + image_path, + regions=regions, + **mit_ocr_default_params, + **gpu_options, + ) + # XXX unrecommended but should not cause dead lock + ocred: list[dict] = ocr.get(disable_sync_subtasks=False) + logger.info("ocr finished: %s", ocred) + for t in ocred: + logger.info("ocr extracted text: %s", t) + return ocred + + +def _run_mit_translate( + text: str, + translator: str = "gpt3.5", + target_lang: str = "CHT", +) -> str: + t: AsyncResult = mit_translate.delay( + query=text, + translator=translator, + target_lang=target_lang, + ) + # XXX unrecommended but should not cause dead lock + + result: str = t.get(disable_sync_subtasks=False) + logger.info("translated %s %s", text, result) + return result[0] + + +@celery_app.task(name="tasks.preprocess_mit") +def _preprocess_mit(image_path: str, target_lang: str): + detected = _run_mit_detect_text(image_path) + ocred = _run_mit_ocr(image_path, detected["textlines"]) + translated_texts = [ + _run_mit_translate(t["text"], target_lang=target_lang) for t in ocred + ] + quads = [ + { + "pts": t["pts"], + "raw_text": t["text"], + "translated": translated_texts[i], + } + for i, t in enumerate(ocred) + ] + return { + "image_path": image_path, + "target_lang": target_lang, + "text_quads": quads, + } + + +@dataclass(frozen=True) +class MitTextQuad: + pts: list[tuple[int, int]] + raw_text: str + translated: str + + def to_dict(self) -> dict[str, Any]: + return { + "pts": self.pts, + "raw_text": self.raw_text, + "translated": self.translated, + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "MitTextQuad": + return cls( + pts=d["pts"], + raw_text=d["raw_text"], + translated=d["translated"], + ) + + +@dataclass(frozen=True) +class MitPreprocessedImage: + image_path: str + target_lang: str + text_quads: list[MitTextQuad] + + def to_dict(self) -> dict[str, Any]: + return { + "image_path": self.image_path, + "target_lang": self.target_lang, + "text_quads": [t.to_dict() for t in self.text_quads], + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "MitPreprocessedImage": + return cls( + image_path=d["image_path"], + target_lang=d["target_lang"], + text_quads=[MitTextQuad.from_dict(t) for t in d["text_quads"]], + ) + + +# export tasks with a better type +mit_detect_text: Task = _mit_detect_text # type: ignore +mit_ocr: Task = _mit_ocr # type: ignore +mit_translate: Task = _mit_translate # type: ignore +mit_inpaint: Task = _mit_inpaint # type: ignore +preprocess_mit: Task = _preprocess_mit # type: ignore diff --git a/manage.py b/manage.py index 0f6b6e3..9d48102 100644 --- a/manage.py +++ b/manage.py @@ -1,7 +1,6 @@ import os - +import re import click - import logging from app import create_app, init_db @@ -91,9 +90,42 @@ def local(action): ) +@click.command("mit_file") +@click.option("--file", help="path to image file") +def mit_preprocess_file(file: str): + from app.tasks.mit import preprocess_mit, MitPreprocessedImage + + proprocessed = preprocess_mit.delay(file, "CHT") + proprocessed_result: dict = proprocessed.get() + + print("proprocessed", proprocessed_result) + print("proprocessed", MitPreprocessedImage.from_dict(proprocessed_result)) + + +@click.command("mit_dir") +@click.option("--dir", help="absolute path to a dir containing image files") +def mit_preprocess_dir(dir: str): + from app.tasks.mit import preprocess_mit, MitPreprocessedImage + + for file in os.listdir(dir): + if not re.match(r".*\.(jpg|png|jpeg)$", file): + continue + full_path = os.path.join(dir, file) + proprocessed = preprocess_mit.delay(full_path, "CHT") + proprocessed_result = MitPreprocessedImage.from_dict(proprocessed.get()) + + print("proprocessed", proprocessed_result) + for q in proprocessed_result.text_quads: + print("text block", q.pts) + print(" ", q.raw_text) + print(" ", q.translated) + + main.add_command(local) main.add_command(docs) main.add_command(migrate) +main.add_command(mit_preprocess_file) +main.add_command(mit_preprocess_dir) if __name__ == "__main__": main()