forked from zyddnys/manga-image-translator
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/main' into moeflow-companion-main
- Loading branch information
Showing
8 changed files
with
231 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import asyncio | ||
import json | ||
import pickle | ||
import requests | ||
from PIL import Image | ||
|
||
async def execute_method(method_name, attributes): | ||
url = f"http://127.0.0.1:5003/execute/{method_name}" | ||
headers = {'Content-Type': 'application/octet-stream'} | ||
|
||
response = requests.post(url, data=pickle.dumps(attributes), headers=headers, stream=True) | ||
|
||
if response.status_code == 200: | ||
buffer = b'' | ||
for chunk in response.iter_content(chunk_size=None): | ||
if chunk: | ||
buffer += chunk | ||
while True: | ||
if len(buffer) >= 5: | ||
status = int.from_bytes(buffer[0:1], byteorder='big') | ||
expected_size = int.from_bytes(buffer[1:5], byteorder='big') | ||
if len(buffer) >= 5 + expected_size: | ||
data = buffer[5:5 + expected_size] | ||
if status == 0: | ||
print("data", pickle.loads(data)) | ||
elif status == 1: | ||
print("log", data) | ||
elif status == 2: | ||
print("error", data) | ||
buffer = buffer[5 + expected_size:] | ||
else: | ||
break | ||
else: | ||
break | ||
else: | ||
print(json.loads(response.content)) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
image = Image.open("../imgs/232264684-5a7bcf8e-707b-4925-86b0-4212382f1680.png") | ||
attributes = {"image": image, "params": {"translator": "none", "inpainter": "none"}} | ||
asyncio.run(execute_method("translate", attributes)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import asyncio | ||
import pickle | ||
from threading import Lock | ||
|
||
import uvicorn | ||
from fastapi import FastAPI, HTTPException, Path, Request, Response | ||
from pydantic import BaseModel | ||
import inspect | ||
|
||
from starlette.responses import StreamingResponse | ||
|
||
from manga_translator import MangaTranslator | ||
|
||
class MethodCall(BaseModel): | ||
method_name: str | ||
attributes: bytes | ||
|
||
|
||
async def load_data(request: Request, method): | ||
attributes_bytes = await request.body() | ||
attributes = pickle.loads(attributes_bytes) | ||
sig = inspect.signature(method) | ||
expected_args = set(sig.parameters.keys()) | ||
provided_args = set(attributes.keys()) | ||
|
||
if expected_args != provided_args: | ||
raise HTTPException(status_code=400, detail="Incorrect number or names of arguments") | ||
return attributes | ||
|
||
|
||
class MangaShare: | ||
def __init__(self, params: dict = None): | ||
self.manga = MangaTranslator(params) | ||
self.host = params.get('host', '127.0.0.1') | ||
self.port = int(params.get('port', '5003')) | ||
self.nonce = params.get('nonce', None) | ||
|
||
# each chunk has a structure like this status_code(int/1byte),len(int/4bytes),bytechunk | ||
# status codes are 0 for result, 1 for progress report, 2 for error | ||
self.progress_queue = asyncio.Queue() | ||
self.lock = Lock() | ||
|
||
async def hook(state: str, finished: bool): | ||
state_data = state.encode("utf-8") | ||
progress_data = b'\x01' + len(state_data).to_bytes(4, 'big') + state_data | ||
await self.progress_queue.put(progress_data) | ||
await asyncio.sleep(0) | ||
|
||
self.manga.add_progress_hook(hook) | ||
|
||
async def progress_stream(self): | ||
""" | ||
loops until the status is != 1 which is eiter an error or the result | ||
""" | ||
while True: | ||
progress = await self.progress_queue.get() | ||
yield progress | ||
if progress[0] != 1: | ||
break | ||
|
||
async def run_method(self, method, **attributes): | ||
try: | ||
if asyncio.iscoroutinefunction(method): | ||
result = await method(**attributes) | ||
else: | ||
result = method(**attributes) | ||
result_bytes = pickle.dumps(result) | ||
encoded_result = b'\x00' + len(result_bytes).to_bytes(4, 'big') + result_bytes | ||
await self.progress_queue.put(encoded_result) | ||
except Exception as e: | ||
err_bytes = str(e).encode("utf-8") | ||
encoded_result = b'\x02' + len(err_bytes).to_bytes(4, 'big') + err_bytes | ||
await self.progress_queue.put(encoded_result) | ||
finally: | ||
self.lock.release() | ||
|
||
|
||
def check_nonce(self, request: Request): | ||
if self.nonce: | ||
nonce = request.headers.get('X-Nonce') | ||
if nonce != self.nonce: | ||
raise HTTPException(401, detail="Nonce does not match") | ||
|
||
def check_lock(self): | ||
if not self.lock.acquire(blocking=False): | ||
raise HTTPException(status_code=429, detail="some Method is already being executed.") | ||
|
||
def get_fn(self, method_name: str): | ||
if method_name.startswith("__"): | ||
raise HTTPException(status_code=403, detail="These functions are not allowed to be executed remotely") | ||
method = getattr(self.manga, method_name, None) | ||
if not method: | ||
raise HTTPException(status_code=404, detail="Method not found") | ||
return method | ||
|
||
async def listen(self, translation_params: dict = None): | ||
app = FastAPI() | ||
|
||
@app.get("/is_locked") | ||
async def is_locked(): | ||
if self.lock.locked(): | ||
return {"locked": True} | ||
return {"locked": False} | ||
|
||
@app.post("/simple_execute/{method_name}") | ||
async def execute_method(request: Request, method_name: str = Path(...)): | ||
self.check_nonce(request) | ||
self.check_lock() | ||
method = self.get_fn(method_name) | ||
attr = await load_data(request, method) | ||
try: | ||
if asyncio.iscoroutinefunction(method): | ||
result = await method(**attr) | ||
else: | ||
result = method(**attr) | ||
self.lock.release() | ||
result_bytes = pickle.dumps(result) | ||
return Response(content=result_bytes, media_type="application/octet-stream") | ||
except Exception as e: | ||
self.lock.release() | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.post("/execute/{method_name}") | ||
async def execute_method(request: Request, method_name: str = Path(...)): | ||
self.check_nonce(request) | ||
self.check_lock() | ||
method = self.get_fn(method_name) | ||
attr = await load_data(request, method) | ||
|
||
# streaming response | ||
streaming_response = StreamingResponse(self.progress_stream(), media_type="application/octet-stream") | ||
asyncio.create_task(self.run_method(method, **attr)) | ||
return streaming_response | ||
|
||
config = uvicorn.Config(app, host=self.host, port=self.port) | ||
server = uvicorn.Server(config) | ||
await server.serve() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters