Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: assisted translation as addon #7

Merged
merged 5 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def create_celery() -> Celery:
"app.tasks.ocr",
"app.tasks.import_from_labelplus",
"app.tasks.thumbnail",
"app.tasks.mit", # only included for completeness's sake. its impl is in other repo.
],
related_name=None,
)
Expand Down
87 changes: 87 additions & 0 deletions app/apis/manga_image_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Translation preprocess API backed by manga-image-translator worker
from app.core.views import MoeAPIView
from flask import request

from app.exceptions.base import UploadFileNotFoundError
from app.tasks.mit import (
mit_ocr,
mit_detect_text,
mit_translate,
mit_detect_text_default_params,
mit_ocr_default_params,
)
from app.tasks import queue_task, wait_result_sync
from app import app_config
from werkzeug.datastructures import FileStorage
from tempfile import NamedTemporaryFile
import os
from app.utils.logging import logger

# Directory used by the views below to store uploaded images and to resolve
# the filenames sent with task requests. Falls back to a placeholder path when
# MIT_STORAGE_ROOT is missing from the app config.
MIT_STORAGE_ROOT = app_config.get("MIT_STORAGE_ROOT", "/MIT_STORAGE_ROOT_UNDEFINED")


def _wait_task_result(task_id: str):
    """Wait up to one second for a task and describe its current state.

    The returned dict always carries ``task_id`` and ``status``:

    * ``"success"`` also carries ``result``;
    * ``"pending"`` means the wait timed out;
    * ``"fail"`` also carries ``message``, the stringified exception.
    """
    response = {"task_id": task_id}
    try:
        outcome = wait_result_sync(task_id, timeout=1)
    except TimeoutError:
        response["status"] = "pending"
    except Exception as err:
        # Any other failure is reported to the client instead of raised.
        response.update(status="fail", message=str(err))
    else:
        response.update(status="success", result=outcome)
    return response


class MitImageApi(MoeAPIView):
    """Upload endpoint: stores an image file for the other MIT APIs."""

    def post(self):
        """Save the uploaded ``file`` part into ``MIT_STORAGE_ROOT``.

        Returns:
            ``{"filename": <path of the stored temporary file>}``.

        Raises:
            UploadFileNotFoundError: if no file was sent, or its name does not
                end with ``.jpg``, ``.jpeg``, ``.png`` or ``.gif``
                (compared case-insensitively).
        """
        allowed_extensions = (".jpg", ".jpeg", ".png", ".gif")
        logger.info("files: %s", request.files)
        blob: None | FileStorage = request.files.get("file")
        logger.info("blob: %s", blob)
        # Guard against a missing filename so a nameless upload is rejected
        # with the intended error instead of an AttributeError.
        filename = (blob.filename or "") if blob else ""
        if not filename.lower().endswith(allowed_extensions):
            raise UploadFileNotFoundError("Please select an image file")
        # delete=False: the file must outlive this request so that later task
        # requests can refer to it. The ``with`` block guarantees the handle
        # is closed even if the write fails.
        with NamedTemporaryFile(dir=MIT_STORAGE_ROOT, delete=False) as tmpfile:
            tmpfile.write(blob.read())
        return {"filename": tmpfile.name}


class MitImageTaskApi(MoeAPIView):
    """Create and query image tasks (text detection / OCR) for a stored file."""

    def post(self):
        """Queue a ``mit_detect_text`` or ``mit_ocr`` task for a stored image.

        The JSON body must carry ``filename`` and ``task_name``; every
        remaining key overrides the selected task's default parameters.

        Returns:
            ``{"task_id": <id of the queued task>}``.

        Raises:
            ValueError: if the filename is missing, is not a string, resolves
                outside ``MIT_STORAGE_ROOT``, does not name an existing file,
                or if ``task_name`` is not one of the two supported names.
        """
        task_params: dict[str, str] = self.get_json()
        logger.info("task_params: %s", task_params)
        tmpfile_name = task_params.pop("filename", None)
        if not tmpfile_name:
            raise ValueError("Filename required")
        if not isinstance(tmpfile_name, str):
            raise ValueError("Invalid filename")
        tmpfile_path = os.path.join(MIT_STORAGE_ROOT, tmpfile_name)
        # Security: the filename is client-controlled. os.path.commonprefix
        # compares character by character on the un-normalized join, so "../"
        # segments and sibling directories sharing the prefix would pass.
        # Resolve both paths and compare whole path components instead.
        try:
            storage_root = os.path.realpath(MIT_STORAGE_ROOT)
            resolved_path = os.path.realpath(tmpfile_path)
            inside_root = (
                os.path.commonpath([storage_root, resolved_path]) == storage_root
            )
        except ValueError:
            # e.g. embedded NUL byte, or paths on different drives
            inside_root = False
        if not inside_root:
            raise ValueError("Invalid filename")
        if not os.path.isfile(tmpfile_path):
            raise ValueError("File not found")
        task_name = task_params.pop("task_name", None)
        if task_name == "mit_detect_text":
            task, defaults = mit_detect_text, mit_detect_text_default_params
        elif task_name == "mit_ocr":
            task, defaults = mit_ocr, mit_ocr_default_params
        else:
            raise ValueError("Invalid task name")
        # Request parameters take precedence over the defaults.
        merged_params = {**defaults, **task_params}
        task_id = queue_task(task, tmpfile_path, **merged_params)
        return {"task_id": task_id}

    def get(self, task_id: str):
        """Report the state of the task identified by ``task_id``."""
        return _wait_task_result(task_id)


class MitTranslateTaskApi(MoeAPIView):
    """Create and query translation tasks."""

    def post(self):
        """Queue ``mit_translate`` with the JSON body as keyword arguments."""
        params = self.get_json()
        queued_id = queue_task(mit_translate, **params)
        return {"task_id": queued_id}

    def get(self, task_id: str):
        """Report the state of the task identified by ``task_id``."""
        return _wait_task_result(task_id)
36 changes: 35 additions & 1 deletion app/apis/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
)
from app.apis.language import LanguageListAPI
from app.apis.target import TargetAPI
from app.apis.manga_image_translator import (
MitImageApi,
MitImageTaskApi,
MitTranslateTaskApi,
)
from app import app_config

v1_prefix = "/v1"
# api主页
Expand Down Expand Up @@ -246,7 +252,7 @@
methods=["GET", "PUT", "DELETE", "OPTIONS"],
view_func=ProjectSetAPI.as_view("team_project_set"),
)
# 术语库模块
# Term bank module. TODO: not used (yet?) in moeflow v1.1.0
term_bank = Blueprint("term_bank", __name__, url_prefix=v1_prefix + "/term-banks")
term_bank.add_url_rule(
"/<term_bank_id>",
Expand Down Expand Up @@ -457,3 +463,31 @@
methods=["GET", "OPTIONS"],
view_func=AdminVCodeListAPI.as_view("admin_v_code_list"),
)

# Register the manga-image-translator endpoints only when a storage root is
# configured. ``.get`` is used so a missing MIT_STORAGE_ROOT key disables the
# feature instead of raising KeyError while this module is imported.
if app_config.get("MIT_STORAGE_ROOT"):
    mit = Blueprint("manga-image-translator", __name__, url_prefix=v1_prefix + "/mit")
    # Image upload
    mit.add_url_rule(
        "/images",
        methods=["POST", "OPTIONS"],
        view_func=MitImageApi.as_view("mit_image_upload"),
    )
    # Image tasks (text detection / OCR): create, then query by id
    mit.add_url_rule(
        "/image-tasks",
        methods=["POST", "OPTIONS"],
        view_func=MitImageTaskApi.as_view("mit_image_tasks_create"),
    )
    mit.add_url_rule(
        "/image-tasks/<task_id>",
        methods=["GET", "OPTIONS"],
        view_func=MitImageTaskApi.as_view("mit_image_tasks_query"),
    )
    # Translation tasks: create, then query by id
    mit.add_url_rule(
        "/translate-tasks",
        methods=["POST", "OPTIONS"],
        view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_create"),
    )
    mit.add_url_rule(
        "/translate-tasks/<task_id>",
        methods=["GET", "OPTIONS"],
        view_func=MitTranslateTaskApi.as_view("mit_translate_tasks_query"),
    )
7 changes: 5 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@
# -----------
# Celery
# -----------
CELERY_BROKER_URL = f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}"
CELERY_BACKEND_URL = DB_URI
# An explicit CELERY_BROKER_URL wins. The RabbitMQ-based default is built only
# when no override is present, so the RABBITMQ_* variables are not required
# (and cannot raise KeyError) when the override is set.
CELERY_BROKER_URL = env.get("CELERY_BROKER_URL")
if CELERY_BROKER_URL is None:
    CELERY_BROKER_URL = f"amqp://{env['RABBITMQ_USER']}:{env['RABBITMQ_PASS']}@moeflow-rabbitmq:5672/{env['RABBITMQ_VHOST_NAME']}"
CELERY_BACKEND_URL = env.get("CELERY_BACKEND_URL", DB_URI)
CELERY_MONGODB_BACKEND_SETTINGS = {
"database": env["MONGODB_DB_NAME"],
"taskmeta_collection": "celery_taskmeta",
Expand Down
2 changes: 1 addition & 1 deletion app/models/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create(
file.create_target_cache(target)
return target

def outputs(self):
def outputs(self) -> list[Output]:
"""所有导出"""
outputs = Output.objects(project=self.project, target=self).order_by(
"-create_time"
Expand Down
35 changes: 35 additions & 0 deletions app/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,43 @@
flower --port=5555 --broker=redis://localhost:6379/1
"""

import asyncio
import datetime
from typing import Any
from celery import Task
from celery.result import AsyncResult
from celery.exceptions import TimeoutError as CeleryTimeoutError
from app import celery as celery_app
from asgiref.sync import async_to_sync


class SyncResult:
    """Result object resembling what celery's asynchronous ``delay`` returns.

    Used when switching between synchronous and asynchronous execution.
    """

    # Fixed placeholder id.
    task_id = "sync"


def queue_task(task: Task, *args, **kwargs) -> str:
    """Submit ``task`` with the given arguments and return the new task id.

    ``forget()`` is called on the handle returned by ``delay`` before its id
    is returned; callers look the task up again by id.
    """
    handle = task.delay(*args, **kwargs)
    handle.forget()
    return handle.id


def wait_result_sync(task_id: str, timeout: int = 10) -> Any:
    """Block until the task identified by ``task_id`` yields a result.

    Args:
        task_id: id of a previously queued task.
        timeout: maximum time to wait, passed to ``AsyncResult.get``.

    Returns:
        Whatever ``AsyncResult.get`` returns for the task.

    Raises:
        TimeoutError: the builtin one, if the result is not available in time.
            Celery's own timeout exception is translated so callers do not
            need to import it; it is kept as the ``__cause__``.
    """
    result = AsyncResult(id=task_id, app=celery_app)
    try:
        return result.get(timeout=timeout)
    except CeleryTimeoutError as err:
        raise TimeoutError(
            f"task {task_id} did not finish within {timeout}s"
        ) from err


@async_to_sync
async def wait_result(task_id: str, timeout: int = 10) -> Any:
    """Poll a task until it is ready, then return its result.

    The coroutine is wrapped with ``async_to_sync``, so it is called like a
    regular function.

    Args:
        task_id: id of a previously queued task.
        timeout: maximum number of seconds to keep polling.

    Returns:
        The value of ``AsyncResult.get()`` once the task is ready.

    Raises:
        TimeoutError: if the task is not ready after ``timeout`` seconds; the
            result is forgotten before raising.
    """
    # asyncio.sleep takes seconds. The previous value of 0.5e3 slept 500
    # seconds between polls, far longer than the default timeout.
    poll_interval = 0.5
    start = datetime.datetime.now().timestamp()
    result = AsyncResult(id=task_id, app=celery_app)
    while not result.ready():
        if (datetime.datetime.now().timestamp() - start) > timeout:
            result.forget()
            raise TimeoutError(f"task {task_id} did not finish within {timeout}s")
        await asyncio.sleep(poll_interval)
    return result.get()  # type: ignore
180 changes: 180 additions & 0 deletions app/tasks/mit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""
Text segmentation + OCR using manga-image-translator
"""

from dataclasses import dataclass
from typing import Any
from app import celery as celery_app
from celery import Task
from celery.result import AsyncResult
import logging

# Module-level logger used by the helpers below.
logger = logging.getLogger(__name__)


# Extra keyword arguments added to the detect-text and OCR task calls below;
# requests the "cuda" device.
gpu_options = {"device": "cuda"}


@celery_app.task(name="tasks.mit.detect_text")
def _mit_detect_text(path_or_url: str, **kwargs):
    """Stub registered under the task name ``tasks.mit.detect_text``.

    The real implementation is in manga_translator/moeflow_worker.py; this
    placeholder does nothing and returns ``None``.
    """
    return None


@celery_app.task(name="tasks.mit.ocr")
def _mit_ocr(path_or_url: str, **kwargs):
    """Stub registered under the task name ``tasks.mit.ocr``.

    This placeholder does nothing and returns ``None``.
    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
    return None


@celery_app.task(name="tasks.mit.translate")
def _mit_translate(**kwargs):
    """Stub registered under the task name ``tasks.mit.translate``.

    This placeholder does nothing and returns ``None``.
    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
    return None


@celery_app.task(name="tasks.mit.inpaint")
def _mit_inpaint(path_or_url: str, **kwargs):
    """Stub registered under the task name ``tasks.mit.inpaint``.

    This placeholder does nothing and returns ``None``.
    NOTE(review): presumably implemented by the external worker, like
    ``tasks.mit.detect_text`` -- confirm.
    """
    return None


def to_dict(**kwargs) -> dict[str, Any]:
    """Return the given keyword arguments as a ``dict[str, Any]``."""
    return {**kwargs}


# Default keyword arguments for the text-detection task,
# mostly defaults from manga-image-translator/args.py.
mit_detect_text_default_params: dict[str, Any] = {
    "detector_key": "default",
    "detect_size": 2560,
    "text_threshold": 0.5,
    "box_threshold": 0.7,
    "unclip_ratio": 2.3,
    "invert": False,
    "gamma_correct": False,
    "rotate": False,
    "verbose": True,
}


def _run_mit_detect_text(image_path: str) -> dict:
    """Run the text-detection task on ``image_path`` and wait for its result.

    The task is called with ``mit_detect_text_default_params`` plus
    ``gpu_options``; the result is logged and returned.
    """
    pending: AsyncResult = mit_detect_text.delay(
        image_path, **mit_detect_text_default_params, **gpu_options
    )
    # XXX: blocking on the result with disable_sync_subtasks=False is not
    # recommended, but should not cause a deadlock.
    detected: dict = pending.get(disable_sync_subtasks=False)  # type: ignore
    logger.info("detect_text finished: %s", detected)
    return detected


# Default keyword arguments for the OCR task.
mit_ocr_default_params: dict[str, Any] = {
    "ocr_key": "48px",  # recommended by rowland
    # Alternatives that were considered:
    #   "ocr_key": "48px_ctc"
    #   "ocr_key": "mocr"  (XXX: mocr may have different output format)
    #   "use_mocr_merge": True
    "verbose": True,
}


def _run_mit_ocr(image_path: str, regions: list[dict]) -> list[dict]:
    """Run the OCR task on ``regions`` of ``image_path`` and wait for the result.

    The task is called with ``mit_ocr_default_params`` plus ``gpu_options``.
    The whole result and then each of its entries are logged before it is
    returned.
    """
    pending: AsyncResult = mit_ocr.delay(
        image_path, regions=regions, **mit_ocr_default_params, **gpu_options
    )
    # XXX: blocking on the result with disable_sync_subtasks=False is not
    # recommended, but should not cause a deadlock.
    recognized: list[dict] = pending.get(disable_sync_subtasks=False)
    logger.info("ocr finished: %s", recognized)
    for entry in recognized:
        logger.info("ocr extracted text: %s", entry)
    return recognized


def _run_mit_translate(
    text: str,
    translator: str = "gpt3.5",
    target_lang: str = "CHT",
) -> str:
    """Translate ``text`` with the translate task and wait for the result.

    Args:
        text: text to translate, sent to the task as ``query``.
        translator: translator key forwarded to the task.
        target_lang: target language code forwarded to the task.

    Returns:
        The first element of the task result.
    """
    t: AsyncResult = mit_translate.delay(
        query=text,
        translator=translator,
        target_lang=target_lang,
    )
    # XXX: blocking on the result with disable_sync_subtasks=False is not
    # recommended, but should not cause a deadlock.
    # NOTE(review): the result is indexed with [0] to obtain the returned
    # string, so it is annotated as a list of strings rather than ``str`` --
    # confirm against the worker implementation.
    result: list[str] = t.get(disable_sync_subtasks=False)
    logger.info("translated %s %s", text, result)
    return result[0]


@celery_app.task(name="tasks.preprocess_mit")
def _preprocess_mit(image_path: str, target_lang: str):
    """Detect, OCR and translate the text of one image.

    Runs text detection, feeds the detected ``textlines`` to OCR, then
    translates every OCR entry's ``text`` into ``target_lang``.

    Returns:
        A dict with ``image_path``, ``target_lang`` and ``text_quads``; each
        quad holds the entry's ``pts``, its ``raw_text`` and the ``translated``
        text.
    """
    detection = _run_mit_detect_text(image_path)
    ocr_entries = _run_mit_ocr(image_path, detection["textlines"])
    # Translate everything first, then pair each entry with its translation.
    translations = [
        _run_mit_translate(entry["text"], target_lang=target_lang)
        for entry in ocr_entries
    ]
    text_quads = []
    for entry, translation in zip(ocr_entries, translations):
        text_quads.append(
            {
                "pts": entry["pts"],
                "raw_text": entry["text"],
                "translated": translation,
            }
        )
    return {
        "image_path": image_path,
        "target_lang": target_lang,
        "text_quads": text_quads,
    }


@dataclass(frozen=True)
class MitTextQuad:
    """One recognized text area with its raw and translated text."""

    # Points of the area as (int, int) pairs.
    pts: list[tuple[int, int]]
    # Text as recognized, before translation.
    raw_text: str
    # Translated text.
    translated: str

    def to_dict(self) -> dict[str, Any]:
        """Return the three fields as a plain dict."""
        return dict(
            pts=self.pts,
            raw_text=self.raw_text,
            translated=self.translated,
        )

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitTextQuad":
        """Build an instance from a dict with the keys produced by ``to_dict``."""
        pts, raw_text, translated = d["pts"], d["raw_text"], d["translated"]
        return cls(pts=pts, raw_text=raw_text, translated=translated)


@dataclass(frozen=True)
class MitPreprocessedImage:
    """Preprocessing result of one image: its path, language and text quads."""

    # Path of the processed image.
    image_path: str
    # Language the texts were translated into.
    target_lang: str
    # Text areas found in the image.
    text_quads: list[MitTextQuad]

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, converting each quad to a dict."""
        quad_dicts = [quad.to_dict() for quad in self.text_quads]
        return dict(
            image_path=self.image_path,
            target_lang=self.target_lang,
            text_quads=quad_dicts,
        )

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "MitPreprocessedImage":
        """Build an instance from a dict with the keys produced by ``to_dict``."""
        image_path, target_lang = d["image_path"], d["target_lang"]
        quads = [MitTextQuad.from_dict(item) for item in d["text_quads"]]
        return cls(image_path=image_path, target_lang=target_lang, text_quads=quads)


# export tasks with a better type
# Export the tasks with a better type: each decorated function above is
# re-bound under a public name annotated as a celery ``Task``, which is what
# the helpers in this module call ``.delay`` on.
mit_detect_text: Task = _mit_detect_text # type: ignore
mit_ocr: Task = _mit_ocr # type: ignore
mit_translate: Task = _mit_translate # type: ignore
mit_inpaint: Task = _mit_inpaint # type: ignore
preprocess_mit: Task = _preprocess_mit # type: ignore
Loading