From 8f9d069f050fc6f6680fe8f15bc859deaad34bce Mon Sep 17 00:00:00 2001 From: Sergey Chernyaev Date: Thu, 25 Jan 2024 23:08:45 +0100 Subject: [PATCH] Add subtitles translation using EasyNMT and OpusMT libraries --- .github/workflows/pylint.yml | 1 + .github/workflows/setup.yml | 23 ++++ README.md | 12 +- auto_subtitle/cli.py | 7 +- auto_subtitle/main.py | 78 ++++++++--- auto_subtitle/translation/__init__.py | 0 auto_subtitle/translation/easynmt_utils.py | 24 ++++ auto_subtitle/translation/languages.py | 20 +++ auto_subtitle/translation/opusmt_utils.py | 149 +++++++++++++++++++++ auto_subtitle/utils/ffmpeg.py | 8 +- auto_subtitle/utils/files.py | 2 + auto_subtitle/utils/mytempfile.py | 10 +- auto_subtitle/utils/whisper.py | 8 +- requirements.txt | 7 +- setup.py | 7 +- 15 files changed, 324 insertions(+), 32 deletions(-) create mode 100644 .github/workflows/setup.yml create mode 100644 auto_subtitle/translation/__init__.py create mode 100644 auto_subtitle/translation/easynmt_utils.py create mode 100644 auto_subtitle/translation/languages.py create mode 100644 auto_subtitle/translation/opusmt_utils.py diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 1c93443..44a5757 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -18,6 +18,7 @@ jobs: run: | python -m pip install --upgrade pip pip install pylint + pip install wheel pip install -r requirements.txt - name: Analysing the code with pylint run: | diff --git a/.github/workflows/setup.yml b/.github/workflows/setup.yml new file mode 100644 index 0000000..60619ea --- /dev/null +++ b/.github/workflows/setup.yml @@ -0,0 +1,23 @@ +name: Setup + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install application + run: | + pip install wheel + pip install -e . + - name: Check that package was installed successfully + run: | + faster_auto_subtitle -h diff --git a/README.md b/README.md index 5a765e9..da8bf92 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,9 @@ This repository uses `ffmpeg` and [OpenAI's Whisper](https://openai.com/blog/whi ## Installation -To get started, you'll need Python 3.7 or newer. Install the binary by running the following command: +To get started, you'll need Python 3.9 or newer. Install the binary by running the following command: + + pip install wheel pip install git+https://github.com/Sirozha1337/faster-auto-subtitle.git@dev @@ -37,6 +39,12 @@ Adding `--task translate` will translate the subtitles into English: faster_auto_subtitle /path/to/video.mp4 --task translate +Adding `--target_language {2-letter-language-code}` will translate the subtitles into specified language using [Opus-MT](https://github.com/Helsinki-NLP/Opus-MT): + + faster_auto_subtitle /path/to/video.mp4 --target_language fr + +This will require downloading the appropriate model. If direct translation is not available it will attempt translation from source to english and from english to source. + Run the following to view all available options: faster_auto_subtitle --help @@ -49,7 +57,7 @@ Higher `beam_size` usually leads to greater accuracy, but slows down the process Setting higher `no_speech_threshold` could be useful for videos with a lot of background noise to stop Whisper from "hallucinating" subtitles for it. -In my experience settings option `condition_on_previous_text` to `False` dramatically increases accurracy for videos like TV Shows with an intro song at the start. +In my experience settings option `condition_on_previous_text` to `False` dramatically increases accurracy for videos like TV Shows with an intro song at the start. You can use `sample_interval` parameter to generate subtitles for a portion of the video to play around with those parameters: diff --git a/auto_subtitle/cli.py b/auto_subtitle/cli.py index 6e030f5..bca8743 100644 --- a/auto_subtitle/cli.py +++ b/auto_subtitle/cli.py @@ -46,11 +46,16 @@ def main(): parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') \ - or X->English translation ('translate')") + or X->Language translation ('translate')") parser.add_argument("--language", type=str, default="auto", choices=LANGUAGE_CODES, help="What is the origin language of the video? \ If unset, it is detected automatically.") + parser.add_argument("--target_language", type=str, default="en", + choices=LANGUAGE_CODES, + help="Desired language to translate subtitles to. \ + If language is not en, Opus-MT will be used. \ + See https://github.com/Helsinki-NLP/Opus-MT.") args = parser.parse_args().__dict__ diff --git a/auto_subtitle/main.py b/auto_subtitle/main.py index cad112f..83b0c6e 100644 --- a/auto_subtitle/main.py +++ b/auto_subtitle/main.py @@ -1,9 +1,9 @@ import os import warnings -import tempfile from .utils.files import filename, write_srt from .utils.ffmpeg import get_audio, overlay_subtitles from .utils.whisper import WhisperAI +from .translation.easynmt_utils import EasyNMTWrapper def process(args: dict): @@ -12,7 +12,8 @@ def process(args: dict): output_srt: bool = args.pop("output_srt") srt_only: bool = args.pop("srt_only") language: str = args.pop("language") - sample_interval: str = args.pop("sample_interval") + sample_interval: list = args.pop("sample_interval") + target_language: str = args.pop("target_language") os.makedirs(output_dir, exist_ok=True) @@ -20,20 +21,36 @@ def process(args: dict): warnings.warn( f"{model_name} is an English-only model, forcing English detection.") args["language"] = "en" + language = "en" # if translate task used and language argument is set, then use it elif language != "auto": args["language"] = language + if target_language != 'en': + warnings.warn( + f"{target_language} is not English, Opus-MT will be used to perform translation.") + args['task'] = 'transcribe' + audios = get_audio(args.pop("video"), args.pop( 'audio_channel'), sample_interval) - model_args = {} - model_args["model_size_or_path"] = model_name - model_args["device"] = args.pop("device") - model_args["compute_type"] = args.pop("compute_type") + model_args = { + "model_size_or_path": model_name, + "device": args.pop("device"), + "compute_type": args.pop("compute_type") + } + + subtitles = get_subtitles(audios, model_args, args) + print('Subtitles generated.') + + if target_language != 'en': + print('Translating subtitles... This might take a while.') + subtitles = translate_subtitles( + subtitles, language, target_language, model_args) - srt_output_dir = output_dir if output_srt or srt_only else tempfile.gettempdir() - subtitles = get_subtitles(audios, srt_output_dir, model_args, args) + if output_srt or srt_only: + print('Saving subtitle files...') + save_subtitles(subtitles, output_dir) if srt_only: return @@ -41,23 +58,48 @@ def process(args: dict): overlay_subtitles(subtitles, output_dir, sample_interval) -def get_subtitles(audio_paths: list, output_dir: str, - model_args: dict, transcribe_args: dict): +def translate_subtitles(subtitles: dict, source_lang: str, target_lang: str, model_args: dict): + model = EasyNMTWrapper(device=model_args['device']) + + translated_subtitles = {} + for key, subtitle in subtitles.items(): + src_lang = source_lang + if src_lang == '' or src_lang is None: + src_lang = subtitle['language'] + + translated_segments = model.translate( + subtitle['segments'], src_lang, target_lang) + + translated_subtitle = subtitle.copy() + translated_subtitle['segments'] = translated_segments + translated_subtitles[key] = translated_subtitle + + return translated_subtitles + + +def save_subtitles(subtitles: dict, output_dir: str): + for path, subtitle in subtitles.items(): + subtitle["output_path"] = os.path.join( + output_dir, f"{filename(path)}.srt") + + print(f'Saving to path {subtitle["output_path"]}') + with open(subtitle['output_path'], "w", encoding="utf-8") as srt: + write_srt(subtitle['segments'], file=srt) + + +def get_subtitles(audio_paths: dict, model_args: dict, transcribe_args: dict): model = WhisperAI(model_args, transcribe_args) - subtitles_path = {} + subtitles = {} for path, audio_path in audio_paths.items(): print( f"Generating subtitles for {filename(path)}... This might take a while." ) - srt_path = os.path.join(output_dir, f"{filename(path)}.srt") - - segments = model.transcribe(audio_path) - with open(srt_path, "w", encoding="utf-8") as srt: - write_srt(segments, file=srt) + segments, info = model.transcribe(audio_path) - subtitles_path[path] = srt_path + subtitles[path] = {'segments': list( + segments), 'language': info.language} - return subtitles_path + return subtitles diff --git a/auto_subtitle/translation/__init__.py b/auto_subtitle/translation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/auto_subtitle/translation/easynmt_utils.py b/auto_subtitle/translation/easynmt_utils.py new file mode 100644 index 0000000..83bce1a --- /dev/null +++ b/auto_subtitle/translation/easynmt_utils.py @@ -0,0 +1,24 @@ +from easynmt import EasyNMT +from faster_whisper.transcribe import Segment +from .opusmt_utils import OpusMT + + +class EasyNMTWrapper: + def __init__(self, device): + self.translator = OpusMT() + self.model = EasyNMT('opus-mt', + translator=self.translator, + device=device if device != 'auto' else None) + + def translate(self, segments: list[Segment], source_lang: str, target_lang: str): + source_text = [segment.text for segment in segments] + self.translator.load_available_models() + + translated_text = self.model.translate(source_text, target_lang, + source_lang, show_progress_bar=True) + translated_segments = [None] * len(segments) + for index, segment in enumerate(segments): + translated_segments[index] = segment._replace( + text=translated_text[index]) + + return translated_segments diff --git a/auto_subtitle/translation/languages.py b/auto_subtitle/translation/languages.py new file mode 100644 index 0000000..1e4eec5 --- /dev/null +++ b/auto_subtitle/translation/languages.py @@ -0,0 +1,20 @@ +import langcodes +from transformers.models.marian.convert_marian_tatoeba_to_pytorch import GROUP_MEMBERS + + +def to_alpha2_languages(languages): + return set(item for sublist in [__to_alpha2_language(language) for language in languages] for item in sublist) + + +def __to_alpha2_language(language): + if len(language) == 2: + return [language] + + if language in GROUP_MEMBERS: + return set([langcodes.Language.get(x).language for x in GROUP_MEMBERS[language][1]]) + + return [langcodes.Language.get(language).language] + + +def to_alpha3_language(language): + return langcodes.Language.get(language).to_alpha3() diff --git a/auto_subtitle/translation/opusmt_utils.py b/auto_subtitle/translation/opusmt_utils.py new file mode 100644 index 0000000..2dd905f --- /dev/null +++ b/auto_subtitle/translation/opusmt_utils.py @@ -0,0 +1,149 @@ +import time +import logging +from typing import List +import torch +from huggingface_hub import list_models, ModelFilter +from transformers import MarianMTModel, MarianTokenizer +from .languages import to_alpha2_languages, to_alpha3_language + +logger = logging.getLogger(__name__) + +NLP_ROOT = 'Helsinki-NLP' + + +class OpusMT: + def __init__(self, max_loaded_models: int = 10): + self.models = {} + self.max_loaded_models = max_loaded_models + self.max_length = None + + self.available_models = None + self.translations_graph = None + + def load_model(self, model_name): + if model_name in self.models: + self.models[model_name]['last_loaded'] = time.time() + return self.models[model_name]['tokenizer'], self.models[model_name]['model'] + + logger.info("Load model: {}" % model_name) + tokenizer = MarianTokenizer.from_pretrained(model_name) + model = MarianMTModel.from_pretrained(model_name) + model.eval() + + if len(self.models) >= self.max_loaded_models: + oldest_time = time.time() + oldest_model = None + for loaded_model_name in self.models.keys(): + if self.models[loaded_model_name]['last_loaded'] <= oldest_time: + oldest_model = loaded_model_name + oldest_time = self.models[loaded_model_name]['last_loaded'] + del self.models[oldest_model] + + self.models[model_name] = { + 'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()} + return tokenizer, model + + def load_available_models(self): + if self.available_models is not None: + return + + print('Loading a list of available language models from OPUS-NT') + model_list = list_models( + filter=ModelFilter( + author=NLP_ROOT + ) + ) + + suffix = [x.modelId.split("/")[1] for x in model_list + if x.modelId.startswith(f'{NLP_ROOT}/opus-mt') and 'tc' not in x.modelId] + + models = [DownloadableModel(f"{NLP_ROOT}/{s}") + for s in suffix if s == s.lower()] + + self.available_models = {} + for model in models: + for src in model.source_languages: + for tgt in model.target_languages: + key = f'{src}-{tgt}' + if key not in self.available_models: + self.available_models[key] = model + elif self.available_models[key].language_count > model.language_count: + self.available_models[key] = model + + def determine_required_translations(self, source_lang, target_lang): + direct_key = f'{source_lang}-{target_lang}' + if direct_key in self.available_models: + print( + f'Found direct translation from {source_lang} to {target_lang}.') + return [(source_lang, target_lang, direct_key)] + + print( + f'No direct translation from {source_lang} to {target_lang}. Trying to translate through en.') + + to_en_key = f'{source_lang}-en' + if to_en_key not in self.available_models: + print(f'No translation from {source_lang} to en.') + return [] + + from_en_key = f'en-{target_lang}' + if from_en_key not in self.available_models: + print(f'No translation from en to {target_lang}.') + return [] + + return [(source_lang, 'en', to_en_key), ('en', target_lang, from_en_key)] + + def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs): + self.load_available_models() + + translations = self.determine_required_translations( + source_lang, target_lang) + + if len(translations) == 0: + return sentences + + intermediate = sentences + for _, tgt_lang, key in translations: + model_data = self.available_models[key] + model_name = model_data.name + tokenizer, model = self.load_model(model_name) + model.to(device) + + if model_data.multilanguage: + alpha3 = to_alpha3_language(tgt_lang) + prefix = next( + x for x in tokenizer.supported_language_codes if alpha3 in x) + intermediate = [f'{prefix} {x}' for x in intermediate] + + inputs = tokenizer(intermediate, truncation=True, padding=True, + max_length=self.max_length, return_tensors="pt") + + for key in inputs: + inputs[key] = inputs[key].to(device) + + with torch.no_grad(): + translated = model.generate( + **inputs, num_beams=beam_size, **kwargs) + intermediate = [tokenizer.decode( + t, skip_special_tokens=True) for t in translated] + + return intermediate + + +class DownloadableModel: + def __init__(self, name): + self.name = name + source_languages, target_languages = self.parse_languages(name) + self.source_languages = source_languages + self.target_languages = target_languages + self.multilanguage = len(self.target_languages) > 1 + self.language_count = len( + self.source_languages) + len(self.target_languages) + + @staticmethod + def parse_languages(name): + parts = name.split('-') + if len(parts) > 5: + return set(), set() + + src, tgt = parts[3], parts[4] + return to_alpha2_languages(src.split('_')), to_alpha2_languages(tgt.split('_')) diff --git a/auto_subtitle/utils/ffmpeg.py b/auto_subtitle/utils/ffmpeg.py index 9f6fdd4..6950b19 100644 --- a/auto_subtitle/utils/ffmpeg.py +++ b/auto_subtitle/utils/ffmpeg.py @@ -2,7 +2,7 @@ import tempfile import ffmpeg from .mytempfile import MyTempFile -from .files import filename +from .files import filename, write_srt def get_audio(paths: list, audio_channel_index: int, sample_interval: list): @@ -38,7 +38,7 @@ def get_audio(paths: list, audio_channel_index: int, sample_interval: list): def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list): - for path, srt_path in subtitles.items(): + for path, subtitle in subtitles.items(): out_path = os.path.join(output_dir, f"{filename(path)}.mp4") print(f"Adding subtitles to {filename(path)}...") @@ -55,7 +55,9 @@ def overlay_subtitles(subtitles: dict, output_dir: str, sample_interval: list): # HACK: On Windows it's impossible to use absolute subtitle file path with ffmpeg # so we use temp copy instead # see: https://github.com/kkroening/ffmpeg-python/issues/745 - with MyTempFile(srt_path) as srt_temp: + with MyTempFile(subtitle['output_path'] if 'output_path' in subtitle else None) as srt_temp: + write_srt(subtitle['segments'], srt_temp.tmp_file) + video = ffmpeg.input(path, **ffmpeg_input_args) audio = video.audio diff --git a/auto_subtitle/utils/files.py b/auto_subtitle/utils/files.py index 8a9476b..ea40253 100644 --- a/auto_subtitle/utils/files.py +++ b/auto_subtitle/utils/files.py @@ -2,6 +2,7 @@ from typing import Iterator, TextIO from .convert import format_timestamp + def write_srt(transcript: Iterator[dict], file: TextIO): for i, segment in enumerate(transcript, start=1): print( @@ -13,5 +14,6 @@ def write_srt(transcript: Iterator[dict], file: TextIO): flush=True, ) + def filename(path: str): return os.path.splitext(os.path.basename(path))[0] diff --git a/auto_subtitle/utils/mytempfile.py b/auto_subtitle/utils/mytempfile.py index 372c74d..51e34ae 100644 --- a/auto_subtitle/utils/mytempfile.py +++ b/auto_subtitle/utils/mytempfile.py @@ -2,6 +2,7 @@ import os import shutil + class MyTempFile: """ A context manager for creating a temporary file in current directory, copying the content from @@ -18,15 +19,18 @@ class MyTempFile: Args: - file_path (str): The path to the file whose content will be copied to the temporary file. """ - def __init__(self, file_path): + + def __init__(self, file_path: str = None): self.file_path = file_path self.tmp_file = None self.tmp_file_path = None def __enter__(self): - self.tmp_file = tempfile.NamedTemporaryFile('w', dir='.', delete=False) + self.tmp_file = tempfile.NamedTemporaryFile('w', encoding="utf-8", dir='.', delete=False) self.tmp_file_path = os.path.relpath(self.tmp_file.name, '.') - shutil.copyfile(self.file_path, self.tmp_file_path) + + if self.file_path is not None and os.path.isfile(self.file_path): + shutil.copyfile(self.file_path, self.tmp_file_path) return self def __exit__(self, exc_type, exc_value, exc_traceback): diff --git a/auto_subtitle/utils/whisper.py b/auto_subtitle/utils/whisper.py index 9d21972..345a717 100644 --- a/auto_subtitle/utils/whisper.py +++ b/auto_subtitle/utils/whisper.py @@ -2,7 +2,7 @@ import faster_whisper from tqdm import tqdm -# pylint: disable=R0903 + class WhisperAI: """ Wrapper class for the Whisper speech recognition model with additional functionality. @@ -52,9 +52,13 @@ def transcribe(self, audio_path: str): - faster_whisper.TranscriptionSegment: An individual transcription segment. """ warnings.filterwarnings("ignore") - segments, info = self.model.transcribe(audio_path, **self.transcribe_args) + segments, info = self.model.transcribe( + audio_path, **self.transcribe_args) warnings.filterwarnings("default") + return (self.subtitles_iterator(segments, info), info) + + def subtitles_iterator(self, segments, info): # Same precision as the Whisper timestamps. total_duration = round(info.duration, 2) diff --git a/requirements.txt b/requirements.txt index eab95da..71da213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,8 @@ faster-whisper==0.10.0 tqdm==4.56.0 -ffmpeg-python==0.2.0 \ No newline at end of file +ffmpeg-python==0.2.0 +wheel==0.42.0 +fasttext==0.9.2 +pybind11==2.11.1 +EasyNMT==2.0.2 +langcodes==3.3.0 \ No newline at end of file diff --git a/setup.py b/setup.py index c185e54..2758ca7 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( - version="1.0", + version="1.1", name="faster_auto_subtitle", packages=find_packages(), py_modules=["auto_subtitle"], @@ -9,7 +9,10 @@ install_requires=[ 'faster-whisper', 'tqdm', - 'ffmpeg-python' + 'ffmpeg-python', + 'fasttext', + 'EasyNMT', + 'langcodes', ], description="Automatically generate and embed subtitles into your videos", entry_points={