From ce19686e99608d5923599e2a3ea13545dce6deaf Mon Sep 17 00:00:00 2001 From: Michele Yin Date: Fri, 15 Dec 2023 11:20:21 +0100 Subject: [PATCH 1/5] added support for mbart50 and move of nllb and mbart models to the models/translators directory. --- manga_translator/translators/__init__.py | 2 + manga_translator/translators/mbart50.py | 124 +++++++++++++++++++++++ manga_translator/translators/nllb.py | 9 +- 3 files changed, 132 insertions(+), 3 deletions(-) create mode 100644 manga_translator/translators/mbart50.py diff --git a/manga_translator/translators/__init__.py b/manga_translator/translators/__init__.py index 783a91b10..eb7255d43 100644 --- a/manga_translator/translators/__init__.py +++ b/manga_translator/translators/__init__.py @@ -11,6 +11,7 @@ from .nllb import NLLBTranslator, NLLBBigTranslator from .sugoi import JparacrawlTranslator, JparacrawlBigTranslator, SugoiTranslator from .m2m100 import M2M100Translator, M2M100BigTranslator +from .mbart50 import MBart50Translator from .selective import SelectiveOfflineTranslator, prepare as prepare_selective_translator from .none import NoneTranslator from .original import OriginalTranslator @@ -24,6 +25,7 @@ 'jparacrawl_big': JparacrawlBigTranslator, 'm2m100': M2M100Translator, 'm2m100_big': M2M100BigTranslator, + 'mbart50': MBart50Translator, } TRANSLATORS = { diff --git a/manga_translator/translators/mbart50.py b/manga_translator/translators/mbart50.py new file mode 100644 index 000000000..18bb0a8a2 --- /dev/null +++ b/manga_translator/translators/mbart50.py @@ -0,0 +1,124 @@ +import os +import py3langid as langid + + +from .common import OfflineTranslator + +ISO_639_1_TO_MBart50 = { + 'ara': 'ar_AR', + 'deu': 'de_DE', + 'eng': 'en_XX', + 'spa': 'es_XX', + 'fra': 'fr_XX', + 'hin': 'hi_IN', + 'ita': 'it_IT', + 'jpn': 'ja_XX', + 'nld': 'nl_XX', + 'pol': 'pl_PL', + 'por': 'pt_XX', + 'rus': 'ru_RU', + 'swa': 'sw_KE', + 'tha': 'th_TH', + 'tur': 'tr_TR', + 'urd': 'ur_PK', + 'vie': 'vi_VN', + 'zho': 'zh_CN', + +} + +class MBart50Translator(OfflineTranslator): + # https://huggingface.co/facebook/mbart-large-50 + # other languages can be added as well + _LANGUAGE_CODE_MAP = { + "ARA": "ar_AR", + "DEU": "de_DE", + "ENG": "en_XX", + "ESP": "es_XX", + "FRA": "fr_XX", + "HIN": "hi_IN", + "ITA": "it_IT", + "JPN": "ja_XX", + "NLD": "nl_XX", + "PLK": "pl_PL", + "PTB": "pt_XX", + "RUS": "ru_RU", + "SWA": "sw_KE", + "THA": "th_TH", + "TRK": "tr_TR", + "URD": "ur_PK", + "VIN": "vi_VN", + "CHS": "zh_CN", + } + + _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'mbart50') + + _TRANSLATOR_MODEL = "facebook/mbart-large-50-many-to-many-mmt" + + + + async def _load(self, from_lang: str, to_lang: str, device: str): + from transformers import ( + MBartForConditionalGeneration, + AutoTokenizer, + ) + if ':' not in device: + device += ':0' + self.device = device + self.model = MBartForConditionalGeneration.from_pretrained(self._TRANSLATOR_MODEL) + self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL) + + async def _unload(self): + del self.model + del self.tokenizer + + async def _infer(self, from_lang: str, to_lang: str, queries: list[str]) -> list[str]: + if from_lang == 'auto': + detected_lang = langid.classify('\n'.join(queries))[0] + target_lang = self._map_detected_lang_to_translator(detected_lang) + + if target_lang == None: + self.logger.warn('Could not detect language from over all sentence. Will try per sentence.') + else: + from_lang = target_lang + + return [self._translate_sentence(from_lang, to_lang, query) for query in queries] + + def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str: + from transformers import pipeline + + if not self.is_loaded(): + return '' + + if from_lang == 'auto': + detected_lang = langid.classify(query)[0] + from_lang = self._map_detected_lang_to_translator(detected_lang) + + if from_lang == None: + self.logger.warn(f'MBart50 Translation Failed. Could not detect language (Or language not supported for text: {query})') + return '' + + translator = pipeline('translation', + device=self.device, + model=self.model, + tokenizer=self.tokenizer, + src_lang=from_lang, + tgt_lang=to_lang, + max_length = 512, + ) + + result = translator(query)[0]['translation_text'] + return result + + def _map_detected_lang_to_translator(self, lang): + if lang not in ISO_639_1_TO_MBart50: + return None + + return ISO_639_1_TO_MBart50[lang] + + async def _download(self): + import huggingface_hub + huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR) + + def _check_downloaded(self) -> bool: + import huggingface_hub + return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None diff --git a/manga_translator/translators/nllb.py b/manga_translator/translators/nllb.py index 4a8ffb7b7..fb75b35bc 100644 --- a/manga_translator/translators/nllb.py +++ b/manga_translator/translators/nllb.py @@ -1,3 +1,4 @@ +import os from typing import List import py3langid as langid @@ -57,6 +58,7 @@ class NLLBTranslator(OfflineTranslator): 'THA': 'tha_Thai', 'IND': 'ind_Latn' } + _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb') _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M' async def _load(self, from_lang: str, to_lang: str, device: str): @@ -118,11 +120,12 @@ def _map_detected_lang_to_translator(self, lang): async def _download(self): import huggingface_hub - huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL) + huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR) def _check_downloaded(self) -> bool: import huggingface_hub - return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin') is not None + return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None class NLLBBigTranslator(NLLBTranslator): - _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B' + _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big') + _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B' \ No newline at end of file From dd63bdc6882beefd319c3821a0fc7bd21d99f87d Mon Sep 17 00:00:00 2001 From: Michele Yin Date: Fri, 15 Dec 2023 14:29:43 +0100 Subject: [PATCH 2/5] filter to remove unused files and save space --- manga_translator/translators/mbart50.py | 68 +++++++++++++------------ manga_translator/translators/nllb.py | 4 +- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/manga_translator/translators/mbart50.py b/manga_translator/translators/mbart50.py index 18bb0a8a2..2e6a9d897 100644 --- a/manga_translator/translators/mbart50.py +++ b/manga_translator/translators/mbart50.py @@ -5,24 +5,26 @@ from .common import OfflineTranslator ISO_639_1_TO_MBart50 = { - 'ara': 'ar_AR', - 'deu': 'de_DE', - 'eng': 'en_XX', - 'spa': 'es_XX', - 'fra': 'fr_XX', - 'hin': 'hi_IN', - 'ita': 'it_IT', - 'jpn': 'ja_XX', - 'nld': 'nl_XX', - 'pol': 'pl_PL', - 'por': 'pt_XX', - 'rus': 'ru_RU', - 'swa': 'sw_KE', - 'tha': 'th_TH', - 'tur': 'tr_TR', - 'urd': 'ur_PK', - 'vie': 'vi_VN', - 'zho': 'zh_CN', + + 'ar': 'ar_AR', + 'de': 'de_DE', + 'en': 'en_XX', + 'es': 'es_XX', + 'fr': 'fr_XX', + 'hi': 'hi_IN', + 'it': 'it_IT', + 'ja': 'ja_XX', + 'nl': 'nl_XX', + 'pl': 'pl_PL', + 'pt': 'pt_XX', + 'ru': 'ru_RU', + 'sw': 'sw_KE', + 'th': 'th_TH', + 'tr': 'tr_TR', + 'ur': 'ur_PK', + 'vi': 'vi_VN', + 'zh': 'zh_CN', + } @@ -65,6 +67,8 @@ async def _load(self, from_lang: str, to_lang: str, device: str): device += ':0' self.device = device self.model = MBartForConditionalGeneration.from_pretrained(self._TRANSLATOR_MODEL) + if self.device != 'cpu': + self.model.to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL) async def _unload(self): @@ -84,7 +88,6 @@ async def _infer(self, from_lang: str, to_lang: str, queries: list[str]) -> list return [self._translate_sentence(from_lang, to_lang, query) for query in queries] def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str: - from transformers import pipeline if not self.is_loaded(): return '' @@ -92,21 +95,19 @@ def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str: if from_lang == 'auto': detected_lang = langid.classify(query)[0] from_lang = self._map_detected_lang_to_translator(detected_lang) - + else: + from_lang = self._LANGUAGE_CODE_MAP(from_lang) if from_lang == None: self.logger.warn(f'MBart50 Translation Failed. Could not detect language (Or language not supported for text: {query})') return '' - - translator = pipeline('translation', - device=self.device, - model=self.model, - tokenizer=self.tokenizer, - src_lang=from_lang, - tgt_lang=to_lang, - max_length = 512, - ) - - result = translator(query)[0]['translation_text'] + + self.tokenizer.src_lang = from_lang + tokens = self.tokenizer(query, return_tensors="pt") + # move to device + if self.device != 'cpu': + tokens = tokens.to(self.device) + generated_tokens = self.model.generate(**tokens, forced_bos_token_id=self.tokenizer.lang_code_to_id[to_lang]) + result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] return result def _map_detected_lang_to_translator(self, lang): @@ -117,8 +118,9 @@ def _map_detected_lang_to_translator(self, lang): async def _download(self): import huggingface_hub - huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR) + # do not download msgpack and h5 files as they are not needed to run the model + huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"]) def _check_downloaded(self) -> bool: import huggingface_hub - return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None + return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None \ No newline at end of file diff --git a/manga_translator/translators/nllb.py b/manga_translator/translators/nllb.py index fb75b35bc..547257eba 100644 --- a/manga_translator/translators/nllb.py +++ b/manga_translator/translators/nllb.py @@ -120,7 +120,9 @@ def _map_detected_lang_to_translator(self, lang): async def _download(self): import huggingface_hub - huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR) + # do not download msgpack and h5 files as they are not needed to run the model + huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"]) + def _check_downloaded(self) -> bool: import huggingface_hub From e015cc1841048f284ccb9553745f33a1d7ebff5b Mon Sep 17 00:00:00 2001 From: Michele Yin Date: Fri, 15 Dec 2023 14:32:13 +0100 Subject: [PATCH 3/5] readme fix --- README.md | 1 + README_CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index bf866a6c2..88de5820a 100644 --- a/README.md +++ b/README.md @@ -267,6 +267,7 @@ Limitations: | sugoi | | ✔️ | Sugoi V4.0 Models (recommended for JPN->ENG) | | m2m100 | | ✔️ | Supports every language | | m2m100_big | | ✔️ | | +| mbart50 | | ✔️ | | | none | | ✔️ | Translate to empty texts | | original | | ✔️ | Keep original texts | diff --git a/README_CN.md b/README_CN.md index 09fdec341..e5c0504bf 100644 --- a/README_CN.md +++ b/README_CN.md @@ -71,6 +71,7 @@ $ pip install git+https://github.com/kodalli/pydensecrf.git | m2m100 | | ✔️ | 可以翻译所有语言 | | m2m100_big | | ✔️ | 带big的是完整尺寸,不带是精简版 | | none | | ✔️ | 翻译成空白文本 | +| mbart50 | | ✔️ | | | original | | ✔️ | 翻译成源文本 | ### 语言代码列表 From 0246d68e609ec6d896c2648c59bc90624266cdd2 Mon Sep 17 00:00:00 2001 From: Michele Yin Date: Fri, 15 Dec 2023 15:02:54 +0100 Subject: [PATCH 4/5] forgot the eval --- manga_translator/translators/mbart50.py | 1 + 1 file changed, 1 insertion(+) diff --git a/manga_translator/translators/mbart50.py b/manga_translator/translators/mbart50.py index 2e6a9d897..79d8656d5 100644 --- a/manga_translator/translators/mbart50.py +++ b/manga_translator/translators/mbart50.py @@ -69,6 +69,7 @@ async def _load(self, from_lang: str, to_lang: str, device: str): self.model = MBartForConditionalGeneration.from_pretrained(self._TRANSLATOR_MODEL) if self.device != 'cpu': self.model.to(self.device) + self.model.eval() self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL) async def _unload(self): From da058a5bde3dad344186a59cc2c00340fbfc8a97 Mon Sep 17 00:00:00 2001 From: Michele Yin Date: Sat, 16 Dec 2023 15:13:47 +0100 Subject: [PATCH 5/5] typo fix for key error --- manga_translator/translators/mbart50.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/manga_translator/translators/mbart50.py b/manga_translator/translators/mbart50.py index 79d8656d5..d41221e84 100644 --- a/manga_translator/translators/mbart50.py +++ b/manga_translator/translators/mbart50.py @@ -96,8 +96,7 @@ def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str: if from_lang == 'auto': detected_lang = langid.classify(query)[0] from_lang = self._map_detected_lang_to_translator(detected_lang) - else: - from_lang = self._LANGUAGE_CODE_MAP(from_lang) + if from_lang == None: self.logger.warn(f'MBart50 Translation Failed. Could not detect language (Or language not supported for text: {query})') return ''