Skip to content

Commit

Permalink
Merge pull request #7 from jokester/experiment/moeflow-companion
Browse files Browse the repository at this point in the history
PoC: assisted translation as addon
  • Loading branch information
jokester authored Jun 1, 2024
2 parents 0d2daf9 + 96d0d3d commit 78c5a3e
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 6 deletions.
1 change: 1 addition & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def create_celery() -> Celery:
"app.tasks.ocr",
"app.tasks.import_from_labelplus",
"app.tasks.thumbnail",
"app.tasks.mit", # only included for completeness's sake. its impl is in other repo.
],
related_name=None,
)
Expand Down
87 changes: 87 additions & 0 deletions app/apis/manga_image_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Translation preprocess API backed by manga-image-translator worker
from app.core.views import MoeAPIView
from flask import request

from app.exceptions.base import UploadFileNotFoundError
from app.tasks.mit import (
mit_ocr,
mit_detect_text,
mit_translate,
mit_detect_text_default_params,
mit_ocr_default_params,
)
from app.tasks import queue_task, wait_result_sync
from app import app_config
from werkzeug.datastructures import FileStorage
from tempfile import NamedTemporaryFile
import os
from app.utils.logging import logger

# Directory where MitImageApi writes uploaded images and where MitImageTaskApi
# looks them up again. Falls back to a placeholder path when the setting is absent.
MIT_STORAGE_ROOT = app_config.get("MIT_STORAGE_ROOT", "/MIT_STORAGE_ROOT_UNDEFINED")


def _wait_task_result(task_id: str):
    """Poll a queued task for up to one second and describe its state.

    Returns a dict that always carries ``task_id`` and ``status``:

    * ``status == "success"`` plus ``result`` when the task result was fetched,
    * ``status == "pending"`` when the wait timed out,
    * ``status == "fail"`` plus ``message`` (the stringified exception) for
      any other error raised while waiting.
    """
    try:
        outcome = wait_result_sync(task_id, timeout=1)
    except TimeoutError:
        # Not finished within the polling window; the client is expected to retry.
        return {"task_id": task_id, "status": "pending"}
    except Exception as err:
        return {"task_id": task_id, "status": "fail", "message": str(err)}
    return {"task_id": task_id, "status": "success", "result": outcome}


class MitImageApi(MoeAPIView):
    """Upload endpoint that stores an image file for the other MIT APIs."""

    def post(self):
        """Store the uploaded ``file`` part as a temp file under MIT_STORAGE_ROOT.

        Returns:
            ``{"filename": <path of the stored file>}``. The file is created
            with ``delete=False`` so it outlives this request.

        Raises:
            UploadFileNotFoundError: if no file was sent, or its name does not
                end with an accepted image extension.
        """
        allowed_extensions = (".jpg", ".jpeg", ".png", ".gif")
        logger.info("files: %s", request.files)
        blob: None | FileStorage = request.files.get("file")
        logger.info("blob: %s", blob)
        filename = (blob.filename or "") if blob else ""
        # Compare case-insensitively so names such as "PAGE.JPG" are accepted too.
        if not filename.lower().endswith(allowed_extensions):
            raise UploadFileNotFoundError("Please select an image file")
        # The context manager closes the handle even if reading/writing fails.
        with NamedTemporaryFile(dir=MIT_STORAGE_ROOT, delete=False) as tmpfile:
            tmpfile.write(blob.read())
        return {"filename": tmpfile.name}


class MitImageTaskApi(MoeAPIView):
    """Create and poll image tasks (text detection / OCR) for an uploaded file."""

    @staticmethod
    def _resolve_uploaded_file(filename: str) -> str:
        """Map a client-supplied filename to an existing file in MIT_STORAGE_ROOT.

        Args:
            filename: name or path sent by the client, e.g. the ``filename``
                value returned by the upload endpoint.

        Returns:
            The normalized absolute path of the file.

        Raises:
            ValueError: if the path resolves outside MIT_STORAGE_ROOT
                ("Invalid filename") or is not a regular file ("File not found").
        """
        storage_root = os.path.abspath(MIT_STORAGE_ROOT)
        candidate = os.path.abspath(os.path.join(storage_root, filename))
        # Containment is checked on fully resolved paths and per path component.
        # A character-wise prefix test (os.path.commonprefix) would accept
        # "<root>/../other" as well as sibling directories like "<root>_evil".
        real_root = os.path.realpath(storage_root)
        real_candidate = os.path.realpath(candidate)
        if os.path.commonpath([real_root, real_candidate]) != real_root:
            raise ValueError("Invalid filename")
        if not os.path.isfile(candidate):
            raise ValueError("File not found")
        return candidate

    def post(self):
        """Queue a detect-text or OCR task for a previously uploaded image.

        The JSON body must contain ``filename`` and ``task_name``
        (``"mit_detect_text"`` or ``"mit_ocr"``). Every remaining key
        overrides the matching default parameter of the chosen task.

        Returns:
            ``{"task_id": <id of the queued task>}``.

        Raises:
            ValueError: for a missing/invalid filename, a missing file, or an
                unknown task name.
        """
        task_params: dict[str, str] = self.get_json()
        logger.info("task_params: %s", task_params)
        tmpfile_name = task_params.pop("filename", None)
        if not tmpfile_name:
            raise ValueError("Filename required")
        tmpfile_path = self._resolve_uploaded_file(tmpfile_name)
        task_name = task_params.pop("task_name", None)
        if task_name == "mit_detect_text":
            task, defaults = mit_detect_text, mit_detect_text_default_params
        elif task_name == "mit_ocr":
            task, defaults = mit_ocr, mit_ocr_default_params
        else:
            raise ValueError("Invalid task name")
        # Client-supplied parameters take precedence over the defaults.
        merged_params = {**defaults, **task_params}
        task_id = queue_task(task, tmpfile_path, **merged_params)
        return {"task_id": task_id}

    def get(self, task_id: str):
        """Report the state of a previously queued image task."""
        return _wait_task_result(task_id)


class MitTranslateTaskApi(MoeAPIView):
    """Create and poll translation tasks."""

    def post(self):
        """Queue a translation task, passing the JSON body as keyword arguments.

        Returns ``{"task_id": <id of the queued task>}``.
        """
        payload = self.get_json()
        return {"task_id": queue_task(mit_translate, **payload)}

    def get(self, task_id: str):
        """Report the state of a previously queued translation task."""
        status = _wait_task_result(task_id)
        return status
36 changes: 35 additions & 1 deletion app/apis/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
)
from app.apis.language import LanguageListAPI
from app.apis.target import TargetAPI
from app.apis.manga_image_translator import (
MitImageApi,
MitImageTaskApi,
MitTranslateTaskApi,
)
from app import app_config

v1_prefix = "/v1"
# api主页
Expand Down Expand Up @@ -246,7 +252,7 @@
methods=["GET", "PUT", "DELETE", "OPTIONS"],
view_func=ProjectSetAPI.as_view("team_project_set"),
)
# 术语库模块
# 术语库模块 TODO: not used (yet?) in moeflow v1.1.0
term_bank = Blueprint("term_bank", __name__, url_prefix=v1_prefix + "/term-banks")
term_bank.add_url_rule(
"/<term_bank_id>",
Expand Down Expand Up @@ -457,3 +463,31 @@
methods=["GET", "OPTIONS"],
view_func=AdminVCodeListAPI.as_view("admin_v_code_list"),
)

# Register the manga-image-translator endpoints only when a storage root is
# configured. `.get` is used instead of indexing so this module still imports
# when the setting is absent (the view module treats the setting as optional too).
if app_config.get("MIT_STORAGE_ROOT"):
    mit = Blueprint("manga-image-translator", __name__, url_prefix=v1_prefix + "/mit")
    # Upload an image for later tasks.
    mit.add_url_rule(
        "/images",
        methods=["POST", "OPTIONS"],
        view_func=MitImageApi.as_view("mit_image_upload"),
    )
    # Create / query image tasks (text detection, OCR).
    mit.add_url_rule(
        "/image-tasks",
        methods=["POST", "OPTIONS"],
        view_func=MitImageTaskApi.as_view("mit_image_tasks_create"),
    )
    mit.add_url_rule(
        "/image-tasks/<task_id>",
        methods=["GET", "OPTIONS"],
        view_func=MitImageTaskApi.as_view("mit_image_tasks_query"),
    )
    # Create / query translation tasks.
    mit.add_url_rule(
        "/translate-tasks",
        methods=["POST", "OPTIONS"],
        view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_create"),
    )
    mit.add_url_rule(
        "/translate-tasks/<task_id>",
        methods=["GET", "OPTIONS"],
        view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_query"),
    )
7 changes: 5 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@
# -----------
# Celery
# -----------
CELERY_BROKER_URL = f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}"
CELERY_BACKEND_URL = DB_URI
# Both URLs can be overridden through the environment. The RabbitMQ default is
# only built when no override is present, so the RABBITMQ_* variables are not
# required once CELERY_BROKER_URL is supplied (an f-string passed as the
# `env.get` default would be evaluated unconditionally).
CELERY_BROKER_URL = env.get("CELERY_BROKER_URL")
if CELERY_BROKER_URL is None:
    CELERY_BROKER_URL = f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}"
CELERY_BACKEND_URL = env.get("CELERY_BACKEND_URL", DB_URI)
CELERY_MONGODB_BACKEND_SETTINGS = {
"database": env["MONGODB_DB_NAME"],
"taskmeta_collection": "celery_taskmeta",
Expand Down
2 changes: 1 addition & 1 deletion app/models/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create(
file.create_target_cache(target)
return target

def outputs(self):
def outputs(self) -> list[Output]:
"""所有导出"""
outputs = Output.objects(project=self.project, target=self).order_by(
"-create_time"
Expand Down
35 changes: 35 additions & 0 deletions app/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,43 @@
flower --port=5555 --broker=redis://localhost:6379/1
"""

import asyncio
import datetime
from typing import Any
from celery import Task
from celery.result import AsyncResult
from celery.exceptions import TimeoutError as CeleryTimeoutError
from app import celery as celery_app
from asgiref.sync import async_to_sync


class SyncResult:
    """Stand-in resembling the result that celery's ``delay`` returns asynchronously.

    Used for switching between synchronous and asynchronous execution.
    """

    task_id: str = "sync"


def queue_task(task: Task, *args, **kwargs) -> str:
    """Submit *task* via ``task.delay`` and return the id of the queued job.

    All positional and keyword arguments are forwarded to ``task.delay``.
    """
    async_result = task.delay(*args, **kwargs)
    # NOTE(review): forget() runs right after submission; presumably it is meant
    # to release backend resources for this handle. Confirm it cannot discard a
    # result that a later wait_result* call still needs.
    async_result.forget()
    task_id = async_result.id
    return task_id


def wait_result_sync(task_id: str, timeout: int = 10) -> Any:
    """Block until the task identified by *task_id* finishes and return its result.

    Args:
        task_id: id previously returned by ``queue_task``.
        timeout: value passed to ``AsyncResult.get`` as its ``timeout``.

    Raises:
        TimeoutError: the builtin one, raised in place of celery's own
            timeout exception.
    """
    pending = AsyncResult(id=task_id, app=celery_app)
    try:
        value = pending.get(timeout=timeout)
    except CeleryTimeoutError:
        raise TimeoutError
    return value


@async_to_sync
async def wait_result(task_id: str, timeout: int = 10) -> Any:
    """Poll a celery task until it is ready and return its result.

    Args:
        task_id: id previously returned by ``queue_task``.
        timeout: maximum number of seconds to keep polling.

    Raises:
        TimeoutError: if the task is not ready within *timeout* seconds; the
            result is forgotten before raising.
    """
    # asyncio.sleep takes seconds. The previous value 0.5e3 paused for 500 s
    # between checks, which made the timeout effectively meaningless.
    poll_interval = 0.5
    start = datetime.datetime.now().timestamp()
    result = AsyncResult(id=task_id, app=celery_app)
    while not result.ready():
        if (datetime.datetime.now().timestamp() - start) > timeout:
            result.forget()
            raise TimeoutError
        await asyncio.sleep(poll_interval)
    return result.get()  # type: ignore
180 changes: 180 additions & 0 deletions app/tasks/mit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
Text segmentation + OCR using manga-image-translator
"""

from dataclasses import dataclass
from typing import Any
from app import celery as celery_app
from celery import Task
from celery.result import AsyncResult
import logging

logger = logging.getLogger(__name__)


# Extra keyword arguments added to the detect-text and OCR task calls below.
# NOTE(review): presumably selects GPU execution in the worker — confirm there.
gpu_options = {"device": "cuda"}


@celery_app.task(name="tasks.mit.detect_text")
def _mit_detect_text(path_or_url: str, **kwargs):
    """Signature-only placeholder for the ``tasks.mit.detect_text`` task.

    The real implementation is in manga_translator/moeflow_worker.py.
    """
    return None


@celery_app.task(name="tasks.mit.ocr")
def _mit_ocr(path_or_url: str, **kwargs):
    """Signature-only placeholder for the ``tasks.mit.ocr`` task; does nothing here.

    NOTE(review): presumably implemented by the external worker, as with
    ``tasks.mit.detect_text`` — confirm.
    """
    return None


@celery_app.task(name="tasks.mit.translate")
def _mit_translate(**kwargs):
    """Signature-only placeholder for the ``tasks.mit.translate`` task; does nothing here.

    NOTE(review): presumably implemented by the external worker, as with
    ``tasks.mit.detect_text`` — confirm.
    """
    return None


@celery_app.task(name="tasks.mit.inpaint")
def _mit_inpaint(path_or_url: str, **kwargs):
    """Signature-only placeholder for the ``tasks.mit.inpaint`` task; does nothing here.

    NOTE(review): presumably implemented by the external worker, as with
    ``tasks.mit.detect_text`` — confirm.
    """
    return None


def to_dict(**kwargs) -> dict[str, Any]:
    """Return the given keyword arguments as a dict mapping name to value."""
    return dict(kwargs)


# Default keyword arguments for the detect-text task; API callers may override
# individual entries.
mit_detect_text_default_params: dict[str, Any] = {
    "detector_key": "default",
    # mostly defaults from manga-image-translator/args.py
    "detect_size": 2560,
    "text_threshold": 0.5,
    "box_threshold": 0.7,
    "unclip_ratio": 2.3,
    "invert": False,
    "gamma_correct": False,
    "rotate": False,
    "verbose": True,
}


def _run_mit_detect_text(image_path: str) -> dict:
    """Run the detect-text task on *image_path* and block until it returns.

    The task is called with the default detection parameters plus
    ``gpu_options``. Its result is logged and returned unchanged.
    """
    pending: AsyncResult = mit_detect_text.delay(
        image_path, **mit_detect_text_default_params, **gpu_options
    )
    # XXX unrecommended (waiting on a subtask) but should not cause dead lock
    detection: dict = pending.get(disable_sync_subtasks=False)  # type: ignore
    logger.info("detect_text finished: %s", detection)
    return detection


# Default keyword arguments for the OCR task; API callers may override
# individual entries.
mit_ocr_default_params: dict[str, Any] = {
    "ocr_key": "48px",  # recommended by rowland
    # Alternatives that were considered:
    #   ocr_key="48px_ctc"
    #   ocr_key="mocr"  (XXX: mocr may have different output format)
    #   use_mocr_merge=True
    "verbose": True,
}


def _run_mit_ocr(image_path: str, regions: list[dict]) -> list[dict]:
    """Run the OCR task for *regions* of *image_path* and block until it returns.

    The task is called with the default OCR parameters plus ``gpu_options``.
    The whole result and each of its entries are logged, then the result is
    returned unchanged.
    """
    pending: AsyncResult = mit_ocr.delay(
        image_path, regions=regions, **mit_ocr_default_params, **gpu_options
    )
    # XXX unrecommended (waiting on a subtask) but should not cause dead lock
    recognized: list[dict] = pending.get(disable_sync_subtasks=False)
    logger.info("ocr finished: %s", recognized)
    for entry in recognized:
        logger.info("ocr extracted text: %s", entry)
    return recognized


def _run_mit_translate(
    text: str,
    translator: str = "gpt3.5",
    target_lang: str = "CHT",
) -> str:
    """Translate *text* through the translate task and return the first result item.

    Args:
        text: source text, sent to the task as ``query``.
        translator: translator key forwarded to the task.
        target_lang: target language code forwarded to the task.
    """
    pending: AsyncResult = mit_translate.delay(
        query=text, translator=translator, target_lang=target_lang
    )
    # XXX unrecommended (waiting on a subtask) but should not cause dead lock
    # NOTE(review): the task result is indexed below, so it is presumably a
    # list holding one translation per query — confirm against the worker.
    translations: list[str] = pending.get(disable_sync_subtasks=False)
    logger.info("translated %s %s", text, translations)
    return translations[0]


@celery_app.task(name="tasks.preprocess_mit")
def _preprocess_mit(image_path: str, target_lang: str):
    """Run detect-text, OCR and translation for one image, one step after another.

    Returns a dict with ``image_path``, ``target_lang`` and ``text_quads``.
    Each quad holds the ``pts`` of an OCR entry, its ``raw_text`` and the
    ``translated`` text.
    """
    detection = _run_mit_detect_text(image_path)
    ocr_entries = _run_mit_ocr(image_path, detection["textlines"])
    # Every text is translated before any quad is assembled.
    translations = [
        _run_mit_translate(entry["text"], target_lang=target_lang)
        for entry in ocr_entries
    ]
    text_quads = []
    for entry, translation in zip(ocr_entries, translations):
        text_quads.append(
            {
                "pts": entry["pts"],
                "raw_text": entry["text"],
                "translated": translation,
            }
        )
    return {
        "image_path": image_path,
        "target_lang": target_lang,
        "text_quads": text_quads,
    }


@dataclass(frozen=True)
class MitTextQuad:
    """Immutable record of one text region: its points, raw text and translation."""

    pts: list[tuple[int, int]]
    raw_text: str
    translated: str

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict keyed by field name."""
        return dict(
            pts=self.pts,
            raw_text=self.raw_text,
            translated=self.translated,
        )

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitTextQuad":
        """Build an instance from a dict with ``pts``, ``raw_text`` and ``translated``.

        Extra keys are ignored; a missing key raises ``KeyError``.
        """
        pts, raw_text, translated = d["pts"], d["raw_text"], d["translated"]
        return cls(pts, raw_text, translated)


@dataclass(frozen=True)
class MitPreprocessedImage:
    """Immutable record of a preprocessed image: path, target language, text quads."""

    image_path: str
    target_lang: str
    text_quads: list[MitTextQuad]

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, with each quad converted via ``to_dict``."""
        quad_dicts = []
        for quad in self.text_quads:
            quad_dicts.append(quad.to_dict())
        return dict(
            image_path=self.image_path,
            target_lang=self.target_lang,
            text_quads=quad_dicts,
        )

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitPreprocessedImage":
        """Build an instance from a dict with ``image_path``, ``target_lang``, ``text_quads``.

        Each item of ``text_quads`` is converted with ``MitTextQuad.from_dict``.
        A missing key raises ``KeyError``.
        """
        image_path = d["image_path"]
        target_lang = d["target_lang"]
        quads = [MitTextQuad.from_dict(item) for item in d["text_quads"]]
        return cls(image_path, target_lang, quads)


# Public aliases for the task functions defined above, annotated as celery
# `Task` so that callers are type-checked against the task interface
# (e.g. `.delay`) instead of the plain function signature.
mit_detect_text: Task = _mit_detect_text # type: ignore
mit_ocr: Task = _mit_ocr # type: ignore
mit_translate: Task = _mit_translate # type: ignore
mit_inpaint: Task = _mit_inpaint # type: ignore
preprocess_mit: Task = _preprocess_mit # type: ignore
Loading

0 comments on commit 78c5a3e

Please sign in to comment.