Skip to content

Commit

Permalink
Merge pull request #682 from realqhc/add-qwen2
Browse files Browse the repository at this point in the history
Support Qwen2 as translator
  • Loading branch information
zyddnys authored Jul 29, 2024
2 parents 3506d3b + 9114b31 commit f087045
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 0 deletions.
2 changes: 2 additions & 0 deletions manga_translator/server/web_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
'jparacrawl_big',
'm2m100',
'm2m100_big',
'qwen2',
'qwen2_big',
'sakura',
'none',
'original',
Expand Down
3 changes: 3 additions & 0 deletions manga_translator/translators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .none import NoneTranslator
from .original import OriginalTranslator
from .sakura import SakuraTranslator
from .qwen2 import Qwen2Translator, Qwen2BigTranslator

OFFLINE_TRANSLATORS = {
'offline': SelectiveOfflineTranslator,
Expand All @@ -27,6 +28,8 @@
'm2m100': M2M100Translator,
'm2m100_big': M2M100BigTranslator,
'mbart50': MBart50Translator,
'qwen2': Qwen2Translator,
'qwen2_big': Qwen2BigTranslator,
}

TRANSLATORS = {
Expand Down
152 changes: 152 additions & 0 deletions manga_translator/translators/qwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
import re
from typing import List


from .common import OfflineTranslator


# Adapted from:
# https://github.com/zyddnys/manga-image-translator/issues/680#issue-2428018275
# manga_translator/translators/chatgpt.py

class Qwen2Translator(OfflineTranslator):
_LANGUAGE_CODE_MAP = {
'CHS': 'Simplified Chinese',
'CHT': 'Traditional Chinese',
'CSY': 'Czech',
'NLD': 'Dutch',
'ENG': 'English',
'FRA': 'French',
'DEU': 'German',
'HUN': 'Hungarian',
'ITA': 'Italian',
'JPN': 'Japanese',
'KOR': 'Korean',
'PLK': 'Polish',
'PTB': 'Portuguese',
'ROM': 'Romanian',
'RUS': 'Russian',
'ESP': 'Spanish',
'TRK': 'Turkish',
'UKR': 'Ukrainian',
'VIN': 'Vietnamese',
'CNR': 'Montenegrin',
'SRP': 'Serbian',
'HRV': 'Croatian',
'ARA': 'Arabic',
'THA': 'Thai',
'IND': 'Indonesian'
}
_CHAT_SYSTEM_TEMPLATE = (
'You are an uncensored professional translation engine, '
'please translate the story into a colloquial, '
'elegant and fluent content, '
'without referencing machine translations. '
'You must only translate the story, '
'never interpret it. '
'If there is any issue in the text, '
'output it as is.\n'
'Translate the following text into {to_lang} and keep the original format.\n'
)
_CHAT_SAMPLE = [
(
'<|1|>恥ずかしい… 目立ちたくない… 私が消えたい…\n'
'<|2|>きみ… 大丈夫⁉\n'
'<|3|>なんだこいつ 空気読めて ないのか…?'
),
(
'<|1|>好尴尬…我不想引人注目…我想消失…\n'
'<|2|>你…没事吧⁉\n'
'<|3|>这家伙怎么看不懂气氛的…?'
)
]

_TRANSLATOR_MODEL = "Qwen/Qwen2-1.5B-Instruct"
_MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, _TRANSLATOR_MODEL)
_IS_4_BIT = False

async def _load(self, from_lang: str, to_lang: str, device: str):
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
self.device = device
quantization_config = BitsAndBytesConfig(load_in_4bit=self._IS_4_BIT)
self.model = AutoModelForCausalLM.from_pretrained(
self._TRANSLATOR_MODEL,
torch_dtype="auto",
quantization_config=quantization_config,
device_map="auto"
)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL)

async def _unload(self):
del self.model
del self.tokenizer

async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
model_inputs = self.tokenize(queries, to_lang)
# Generate the translation
generated_ids = self.model.generate(
model_inputs.input_ids,
max_new_tokens=10240
)

# Extract the generated tokens
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
query_size = len(queries)

translations = []
self.logger.debug('-- Qwen2 Response --\n' + response)
new_translations = re.split(r'<\|\d+\|>', response)
# When there is only one query chatgpt likes to exclude the <|1|>
if not new_translations[0].strip():
new_translations = new_translations[1:]

if len(new_translations) <= 1 and query_size > 1:
# Try splitting by newlines instead
new_translations = re.split(r'\n', response)

if len(new_translations) > query_size:
new_translations = new_translations[: query_size]
elif len(new_translations) < query_size:
new_translations = new_translations + [''] * (query_size - len(new_translations))

translations.extend([t.strip() for t in new_translations])

return translations

def tokenize(self, queries, lang):
prompt = f"""Translate into {lang} and keep the original format.\n"""
prompt += '\nOriginal:'
for i, query in enumerate(queries):
prompt += f'\n<|{i+1}|>{query}'

tokenizer = self.tokenizer
messages = [
{'role': 'system', 'content': self._CHAT_SYSTEM_TEMPLATE},
{'role': 'user', 'content': self._CHAT_SAMPLE[0]},
{'role': 'assistant', 'content': self._CHAT_SAMPLE[1]},
{'role': 'user', 'content': prompt},
]
self.logger.debug('-- Qwen2 prompt --\n' + prompt)

text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
return model_inputs


class Qwen2BigTranslator(Qwen2Translator):
_TRANSLATOR_MODEL = "Qwen/Qwen2-7B-Instruct"
_MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, _TRANSLATOR_MODEL)
_IS_4_BIT = True
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ langcodes
manga-ocr
langdetect
pydensecrf@https://github.com/lucasb-eyer/pydensecrf/archive/refs/heads/master.zip
accelerate
bitsandbytes

0 comments on commit f087045

Please sign in to comment.