Skip to content

Commit

Permalink
Merge pull request #541 from BigEmperor26/main
Browse files Browse the repository at this point in the history
MBART 50 Support
  • Loading branch information
zyddnys authored Mar 22, 2024
2 parents 3cd9383 + 7a02bb0 commit 84a4194
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 3 deletions.
1 change: 1 addition & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ $ pip install git+https://github.com/kodalli/pydensecrf.git
| m2m100 | | ✔️ | 可以翻译所有语言 |
| m2m100_big | | ✔️ | 带big的是完整尺寸,不带是精简版 |
| none | | ✔️ | 翻译成空白文本 |
| mbart50 | | ✔️ | 支持50种语言的多语言离线翻译模型 |
| original | | ✔️ | 翻译成源文本 |

### 语言代码列表
Expand Down
2 changes: 2 additions & 0 deletions manga_translator/translators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .nllb import NLLBTranslator, NLLBBigTranslator
from .sugoi import JparacrawlTranslator, JparacrawlBigTranslator, SugoiTranslator
from .m2m100 import M2M100Translator, M2M100BigTranslator
from .mbart50 import MBart50Translator
from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator
from .none import NoneTranslator
from .original import OriginalTranslator
Expand All @@ -25,6 +26,7 @@
'jparacrawl_big': JparacrawlBigTranslator,
'm2m100': M2M100Translator,
'm2m100_big': M2M100BigTranslator,
'mbart50': MBart50Translator,
}

TRANSLATORS = {
Expand Down
126 changes: 126 additions & 0 deletions manga_translator/translators/mbart50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import py3langid as langid


from .common import OfflineTranslator

# ISO 639-1 codes mapped onto the locale-style language codes that the
# MBart-50 checkpoint expects (e.g. 'en' -> 'en_XX'). Used to turn the
# output of automatic language detection into a usable source language.
ISO_639_1_TO_MBart50 = {
    iso_code: mbart_code
    for iso_code, mbart_code in (
        ('ar', 'ar_AR'),
        ('de', 'de_DE'),
        ('en', 'en_XX'),
        ('es', 'es_XX'),
        ('fr', 'fr_XX'),
        ('hi', 'hi_IN'),
        ('it', 'it_IT'),
        ('ja', 'ja_XX'),
        ('nl', 'nl_XX'),
        ('pl', 'pl_PL'),
        ('pt', 'pt_XX'),
        ('ru', 'ru_RU'),
        ('sw', 'sw_KE'),
        ('th', 'th_TH'),
        ('tr', 'tr_TR'),
        ('ur', 'ur_PK'),
        ('vi', 'vi_VN'),
        ('zh', 'zh_CN'),
    )
}

class MBart50Translator(OfflineTranslator):
    """Offline translator backed by facebook/mbart-large-50-many-to-many-mmt.

    Model card: https://huggingface.co/facebook/mbart-large-50
    Other languages supported by the checkpoint can be added to the maps below.
    """

    # Project 3-letter language codes -> MBart-50 locale-style codes.
    _LANGUAGE_CODE_MAP = {
        "ARA": "ar_AR",
        "DEU": "de_DE",
        "ENG": "en_XX",
        "ESP": "es_XX",
        "FRA": "fr_XX",
        "HIN": "hi_IN",
        "ITA": "it_IT",
        "JPN": "ja_XX",
        "NLD": "nl_XX",
        "PLK": "pl_PL",
        "PTB": "pt_XX",
        "RUS": "ru_RU",
        "SWA": "sw_KE",
        "THA": "th_TH",
        "TRK": "tr_TR",
        "URD": "ur_PK",
        "VIN": "vi_VN",
        "CHS": "zh_CN",
    }

    # Local cache directory for the downloaded model files.
    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'mbart50')

    # HuggingFace hub identifier of the pretrained checkpoint.
    _TRANSLATOR_MODEL = "facebook/mbart-large-50-many-to-many-mmt"

    async def _load(self, from_lang: str, to_lang: str, device: str):
        """Load model and tokenizer onto *device* ('cpu', 'cuda', 'cuda:1', ...)."""
        from transformers import (
            MBartForConditionalGeneration,
            AutoTokenizer,
        )
        # Normalize bare device names (e.g. 'cuda') to an explicit index.
        if ':' not in device:
            device += ':0'
        self.device = device
        self.model = MBartForConditionalGeneration.from_pretrained(self._TRANSLATOR_MODEL)
        if self.device != 'cpu':
            self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL)

    async def _unload(self):
        """Release model and tokenizer references so memory can be reclaimed."""
        del self.model
        del self.tokenizer

    async def _infer(self, from_lang: str, to_lang: str, queries: list[str]) -> list[str]:
        """Translate *queries* from *from_lang* to *to_lang*, one sentence at a time.

        With ``from_lang == 'auto'`` the source language is first detected over
        all queries joined together; if that fails, detection falls back to a
        per-sentence attempt inside :meth:`_translate_sentence`.
        """
        if from_lang == 'auto':
            detected_lang = langid.classify('\n'.join(queries))[0]
            target_lang = self._map_detected_lang_to_translator(detected_lang)

            if target_lang is None:
                self.logger.warning('Could not detect language from over all sentence. Will try per sentence.')
            else:
                from_lang = target_lang

        return [self._translate_sentence(from_lang, to_lang, query) for query in queries]

    def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
        """Translate a single sentence; returns '' when the model is not loaded
        or the source language cannot be determined."""
        if not self.is_loaded():
            return ''

        if from_lang == 'auto':
            detected_lang = langid.classify(query)[0]
            from_lang = self._map_detected_lang_to_translator(detected_lang)

        if from_lang is None:
            self.logger.warning(f'MBart50 Translation Failed. Could not detect language (Or language not supported for text: {query})')
            return ''

        # The tokenizer must know the source language before encoding.
        self.tokenizer.src_lang = from_lang
        tokens = self.tokenizer(query, return_tensors="pt")
        # Move input tensors to the same device as the model.
        if self.device != 'cpu':
            tokens = tokens.to(self.device)
        # Forcing the BOS token selects the target language for generation.
        generated_tokens = self.model.generate(**tokens, forced_bos_token_id=self.tokenizer.lang_code_to_id[to_lang])
        result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return result

    def _map_detected_lang_to_translator(self, lang):
        """Map a detected ISO 639-1 code to an MBart-50 code, or None if unsupported."""
        if lang not in ISO_639_1_TO_MBart50:
            return None

        return ISO_639_1_TO_MBart50[lang]

    async def _download(self):
        """Fetch the checkpoint into the local cache directory."""
        import huggingface_hub
        # do not download msgpack and h5 files as they are not needed to run the model
        huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"])

    def _check_downloaded(self) -> bool:
        """Return True when the PyTorch weights are already present in the cache."""
        import huggingface_hub
        return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None
11 changes: 8 additions & 3 deletions manga_translator/translators/nllb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import List
import py3langid as langid

Expand Down Expand Up @@ -57,6 +58,7 @@ class NLLBTranslator(OfflineTranslator):
'THA': 'tha_Thai',
'IND': 'ind_Latn'
}
_MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb')
_TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M'

async def _load(self, from_lang: str, to_lang: str, device: str):
Expand Down Expand Up @@ -118,11 +120,14 @@ def _map_detected_lang_to_translator(self, lang):

async def _download(self):
import huggingface_hub
huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL)
# do not download msgpack and h5 files as they are not needed to run the model
huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"])


def _check_downloaded(self) -> bool:
import huggingface_hub
return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin') is not None
return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None

class NLLBBigTranslator(NLLBTranslator):
_TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'
_MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big')
_TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'

0 comments on commit 84a4194

Please sign in to comment.