Skip to content

Commit

Permalink
provide mit-worker -based apis
Browse files Browse the repository at this point in the history
  • Loading branch information
jokester committed Apr 22, 2024
1 parent ba17fd3 commit 846fec8
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 2 deletions.
1 change: 1 addition & 0 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def create_celery() -> Celery:
"app.tasks.ocr",
"app.tasks.import_from_labelplus",
"app.tasks.thumbnail",
"app.tasks.mit", # only included for completeness's sake. its impl is in other repo.
],
related_name=None,
)
Expand Down
76 changes: 76 additions & 0 deletions app/apis/manga_image_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Translation preprocess API backed by manga-image-translator worker
from app.core.views import MoeAPIView
from flask import request

from app.exceptions.base import UploadFileNotFoundError
from app.tasks.mit import preprocess_mit, mit_ocr, mit_detect_text, mit_translate
from app.tasks import queue_task, wait_result_sync
from app import app_config
from werkzeug.datastructures import FileStorage
from tempfile import NamedTemporaryFile
import os

MIT_STORAGE_ROOT = app_config.get("MIT_STORAGE_ROOT", "/MIT_STORAGE_ROOT_UNDEFINED")


def _wait_task_result(task_id: str):
try:
result = wait_result_sync(task_id, timeout=1)
return {"id": task_id, "status": "success", "result": result}
except TimeoutError:
return {
"id": task_id,
"status": "pending",
}
except Exception as e:
return {"id": task_id, "status": "fail", "message": str(e)}


class MitImageApi(MoeAPIView):
# upload image file for other APIs
def post(self):
blob: None | FileStorage = request.files.get("file", None)
if not (blob and blob.name.endswith((".jpg", ".jpeg", ".png", ".gif"))):
raise UploadFileNotFoundError("Please select an image file")
tmpfile = NamedTemporaryFile(dir=MIT_STORAGE_ROOT, delete=False)
tmpfile.write(blob.read())
tmpfile.close()
return {"filename": tmpfile.name}


class MitImageTaskApi(MoeAPIView):
_MIT_IMAGE_TASKS = {
'mit_detect_text': mit_detect_text,
'mit_ocr': mit_ocr,
}

def post(self):
task_params = self.get_json()
task_name = task_params.get("task_name", None)
if task_name not in self._MIT_IMAGE_TASKS:
raise ValueError("Invalid task name")
if 'filename' not in task_params:
raise ValueError("Filename required")
tmpfile_name = task_params['filename']
tmpfile_path = os.path.join(MIT_STORAGE_ROOT, tmpfile_name)
if os.path.commonprefix([tmpfile_path, MIT_STORAGE_ROOT]) != MIT_STORAGE_ROOT:
raise ValueError("Invalid filename")
if not os.path.isfile(tmpfile_path):
raise ValueError("File not found")
task_id = queue_task(self._MIT_IMAGE_TASKS[task_name], tmpfile_path, **task_params)
return {"id": task_id}

def get(self, task_id: str):
return _wait_task_result(task_id)


class MitTranslateTaskApi(MoeAPIView):
def post(self):
task_params = self.get_json()
text: str = task_params.get('text')
target_lang = task_params.get('target_lang', 'CHT')
task_id = queue_task(mit_translate, text, target_lang)
return {'task_id': task_id}

def get(self, task_id: str):
return _wait_task_result(task_id)
32 changes: 31 additions & 1 deletion app/apis/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
)
from app.apis.language import LanguageListAPI
from app.apis.target import TargetAPI
from app.apis.manga_image_translator import MitImageApi, MitImageTaskApi, MitTranslateTaskApi
from app import app_config

v1_prefix = "/v1"
# api主页
Expand Down Expand Up @@ -246,7 +248,7 @@
methods=["GET", "PUT", "DELETE", "OPTIONS"],
view_func=ProjectSetAPI.as_view("team_project_set"),
)
# 术语库模块
# 术语库模块 TODO: not used (yet?) in moeflow v1.1.0
term_bank = Blueprint("term_bank", __name__, url_prefix=v1_prefix + "/term-banks")
term_bank.add_url_rule(
"/<term_bank_id>",
Expand Down Expand Up @@ -457,3 +459,31 @@
methods=["GET", "OPTIONS"],
view_func=AdminVCodeListAPI.as_view("admin_v_code_list"),
)

if "MIT_STORAGE_ROOT" in app_config:
mit = Blueprint("manga-image-translator", __name__, url_prefix=v1_prefix + "/mit")
mit.add_url_rule(
'/images',
methods=['POST', 'OPTIONS'],
view_func=MitImageApi.as_view('mit_images')
)
mit.add_url_rule(
"/image-tasks",
methods=["POST", "OPTIONS"],
view_func=MitImageTaskApi.as_view("mit_image_tasks"),
)
mit.add_url_rule(
"/image-tasks/<task_id>",
methods=["GET", "OPTIONS"],
view_func=MitImageTaskApi.as_view("mit_image_tasks"),
)
mit.add_url_rule(
"/translate-tasks",
methods=["POST", "OPTIONS"],
view_func=MitTranslateTaskApi.as_view("mit_translate_tasks"),
)
mit.add_url_rule(
"/translate-tasks/<task_id>",
methods=["GET", "OPTIONS"],
view_func=MitTranslateTaskApi.as_view("mit_translate_tasks"),
)
2 changes: 1 addition & 1 deletion app/models/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create(
file.create_target_cache(target)
return target

def outputs(self):
def outputs(self) -> list[Output]:
"""所有导出"""
outputs = Output.objects(project=self.project, target=self).order_by(
"-create_time"
Expand Down
35 changes: 35 additions & 0 deletions app/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,43 @@
flower --port=5555 --broker=redis://localhost:6379/1
"""

import asyncio
import datetime
from typing import Any
from celery import Task
from celery.result import AsyncResult
from celery.exceptions import TimeoutError as CeleryTimeoutError
from app import celery as celery_app
from asgiref.sync import async_to_sync


class SyncResult:
"""和celery的delay异步返回类似的结果,用于同步、异步切换"""

task_id = "sync"


def queue_task(task: Task, *args, **kwargs) -> str:
result = task.delay(*args, **kwargs)
result.forget()
return result.id


def wait_result_sync(task_id: str, timeout: int = 10) -> Any:
result = AsyncResult(id=task_id, app=celery_app)
try:
return result.get(timeout=timeout)
except CeleryTimeoutError:
raise TimeoutError


@async_to_sync
async def wait_result(task_id: str, timeout: int = 10) -> Any:
start = datetime.datetime.now().timestamp()
result = AsyncResult(id=task_id, app=celery_app)
while not result.ready():
if (datetime.datetime.now().timestamp() - start) > timeout:
result.forget()
raise TimeoutError
await asyncio.sleep(0.5e3)
return result.get() # type: ignore
166 changes: 166 additions & 0 deletions app/tasks/mit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Text segmentation + OCR using manga-image-translator
"""

from dataclasses import dataclass
from typing import Any
from app import celery as celery_app
from celery import Task
from celery.result import AsyncResult
import logging

logger = logging.getLogger(__name__)


gpu_options = {"device": "cuda"}


@celery_app.task(name="tasks.mit.detect_text")
def _mit_detect_text(path_or_url: str, **kwargs):
pass # Real implementation is in manga_translator/moeflow_worker.py


@celery_app.task(name="tasks.mit.ocr")
def _mit_ocr(path_or_url: str, **kwargs):
pass


@celery_app.task(name="tasks.mit.translate")
def _mit_translate(**kwargs):
pass


@celery_app.task(name="tasks.mit.inpaint")
def _mit_inpaint(path_or_url: str, **kwargs):
pass


def _run_mit_detect_text(image_path: str) -> dict:
detect_text: AsyncResult = mit_detect_text.delay(
image_path,
detector_key="default",
# mostly defaults from manga-image-translator/args.py
detect_size=2560,
text_threshold=0.5,
box_threshold=0.7,
unclip_ratio=2.3,
invert=False,
gamma_correct=False,
rotate=False,
verbose=True,
**gpu_options,
)
# XXX unrecommended but should not cause dead lock
result: dict = detect_text.get(disable_sync_subtasks=False) # type: ignore
logger.info("detect_text finished: %s", result)
return result


def _run_mit_ocr(image_path: str, regions: list[dict]) -> list[dict]:
ocr: AsyncResult = mit_ocr.delay(
image_path,
ocr_key="48px", # recommended by rowland
# ocr_key="48px_ctc",
# ocr_key="mocr", # XXX: mocr may have different output format
# use_mocr_merge=True,
regions=regions,
verbose=True,
**gpu_options,
)
# XXX unrecommended but should not cause dead lock
ocred: list[dict] = ocr.get(disable_sync_subtasks=False)
logger.info("ocr finished: %s", ocred)
for t in ocred:
logger.info("ocr extracted text: %s", t)
return ocred


def _run_mit_translate(
text: str,
translator: str = "gpt3.5",
target_lang: str = "CHT",
) -> str:
t: AsyncResult = mit_translate.delay(
query=text,
translator=translator,
target_lang=target_lang,
)
# XXX unrecommended but should not cause dead lock

result: str = t.get(disable_sync_subtasks=False)
logger.info("translated %s %s", text, result)
return result[0]


@celery_app.task(name="tasks.preprocess_mit")
def _preprocess_mit(image_path: str, target_lang: str):
detected = _run_mit_detect_text(image_path)
ocred = _run_mit_ocr(image_path, detected["textlines"])
translated_texts = [
_run_mit_translate(t["text"], target_lang=target_lang) for t in ocred
]
quads = [
{
"pts": t["pts"],
"raw_text": t["text"],
"translated": translated_texts[i],
}
for i, t in enumerate(ocred)
]
return {
"image_path": image_path,
"target_lang": target_lang,
"text_quads": quads,
}


@dataclass(frozen=True)
class MitTextQuad:
pts: list[tuple[int, int]]
raw_text: str
translated: str

def to_dict(self) -> dict[str, Any]:
return {
"pts": self.pts,
"raw_text": self.raw_text,
"translated": self.translated,
}

@classmethod
def from_dict(cls, d: dict[str, Any]) -> "MitTextQuad":
return cls(
pts=d["pts"],
raw_text=d["raw_text"],
translated=d["translated"],
)


@dataclass(frozen=True)
class MitPreprocessedImage:
image_path: str
target_lang: str
text_quads: list[MitTextQuad]

def to_dict(self) -> dict[str, Any]:
return {
"image_path": self.image_path,
"target_lang": self.target_lang,
"text_quads": [t.to_dict() for t in self.text_quads],
}

@classmethod
def from_dict(cls, d: dict[str, Any]) -> "MitPreprocessedImage":
return cls(
image_path=d["image_path"],
target_lang=d["target_lang"],
text_quads=[MitTextQuad.from_dict(t) for t in d["text_quads"]],
)


# export tasks with a better type
mit_detect_text: Task = _mit_detect_text # type: ignore
mit_ocr: Task = _mit_ocr # type: ignore
mit_translate: Task = _mit_translate # type: ignore
mit_inpaint: Task = _mit_inpaint # type: ignore
preprocess_mit: Task = _preprocess_mit # type: ignore
33 changes: 33 additions & 0 deletions manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,42 @@ def local(action):
)


@click.command("mit_file")
@click.option("--file", help="path to image file")
def mit_preprocess_file(file: str):
from app.tasks.mit import preprocess_mit, MitPreprocessedImage

proprocessed = preprocess_mit.delay(file, "CHT")
proprocessed_result: dict = proprocessed.get()

print("proprocessed", proprocessed_result)
print("proprocessed", MitPreprocessedImage.from_dict(proprocessed_result))


@click.command("mit_dir")
@click.option("--dir", help="absolute path to a dir containing image files")
def mit_preprocess_dir(dir: str):
from app.tasks.mit import preprocess_mit, MitPreprocessedImage

for file in os.listdir(dir):
if not re.match(r".*\.(jpg|png|jpeg)$", file):
continue
full_path = os.path.join(dir, file)
proprocessed = preprocess_mit.delay(full_path, "CHT")
proprocessed_result = MitPreprocessedImage.from_dict(proprocessed.get())

print("proprocessed", proprocessed_result)
for q in proprocessed_result.text_quads:
print("text block", q.pts)
print(" ", q.raw_text)
print(" ", q.translated)


main.add_command(local)
main.add_command(docs)
main.add_command(migrate)
main.add_command(mit_preprocess_file)
main.add_command(mit_preprocess_dir)

if __name__ == "__main__":
main()

0 comments on commit 846fec8

Please sign in to comment.