"""Word-level edit parsing for speech editing.

Given an original and a target transcript, these helpers compute a
word-level Levenshtein alignment and turn it into index spans describing
which words of each transcript are affected by every edit.

Operation strings use one character per alignment step:
``"="`` match, ``"s"`` substitution, ``"d"`` deletion (word only in the
original) and ``"i"`` insertion (word only in the target).
"""
import re
from itertools import groupby

# A word is a run of word characters and apostrophes (e.g. "don't").
_WORD_PATTERN = re.compile(r"\b[\w']+\b")

# Human-readable edit-type name for every non-match operation character.
_EDIT_NAMES = {"i": "insertion", "d": "deletion", "s": "substitution"}


def levenshtein_distance(word1, word2):
    """Align two sequences and return the edit distance and operations.

    Args:
        word1: Original sequence (e.g. list of words).
        word2: Target sequence.

    Returns:
        A tuple ``(distance, operations)`` where ``operations`` is a string
        with one character per alignment step (see module docstring).
        When several steps are equally cheap, deletion is preferred over
        insertion, and insertion over match/substitution.
    """
    len1, len2 = len(word1), len(word2)
    # dist[i][j]: edit distance between word1[:i] and word2[:j].
    dist = [[0] * (len2 + 1) for _ in range(len1 + 1)]
    # step[i][j]: the operation chosen to reach cell (i, j).
    step = [[""] * (len2 + 1) for _ in range(len1 + 1)]

    # Against an empty prefix the only option is to delete / insert everything.
    for i in range(1, len1 + 1):
        dist[i][0] = i
        step[i][0] = "d"
    for j in range(1, len2 + 1):
        dist[0][j] = j
        step[0][j] = "i"

    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            cost = 0 if word1[i - 1] == word2[j - 1] else 1
            deletion = dist[i - 1][j] + 1
            insertion = dist[i][j - 1] + 1
            substitution = dist[i - 1][j - 1] + cost
            best = min(deletion, insertion, substitution)

            # Tie-breaking order matters: it decides the reported alignment.
            if best == deletion:
                chosen = "d"
            elif best == insertion:
                chosen = "i"
            else:
                chosen = "s" if cost else "="

            dist[i][j] = best
            step[i][j] = chosen

    # Walk the back-pointers from the bottom-right corner to the origin.
    path = []
    i, j = len1, len2
    while i > 0 or j > 0:
        chosen = step[i][j]
        path.append(chosen)
        if chosen != "i":  # everything but an insertion consumes a word1 item
            i -= 1
        if chosen != "d":  # everything but a deletion consumes a word2 item
            j -= 1
    path.reverse()

    return dist[len1][len2], "".join(path)


def extract_words(sentence):
    """Return the words of ``sentence``, dropping punctuation.

    Apostrophes inside a word are kept, so ``"don't"`` stays one word.
    """
    return _WORD_PATTERN.findall(sentence)


def handle_delete(start, end, orig, new, new_start=None):
    """Record the spans of a deleted run of words.

    Args:
        start: Index of the first deleted word in the original.
        end: Index one past the last deleted word in the original.
        orig: List receiving the original-side span ``[start, end - 1]``.
        new: List receiving the target-side span
            ``[new_start - 1, new_start]``, i.e. the two positions on
            either side of the gap. May contain ``-1`` when the deletion
            is at the very beginning.
        new_start: Target-side position of the gap. Defaults to ``start``.
    """
    anchor = start if new_start is None else new_start
    orig.append([start, end - 1])
    new.append([anchor - 1, anchor])


def handle_insert(start, end, orig, new, new_start=None):
    """Record the spans of an inserted run of words.

    Args:
        start: Original-side position where the words are inserted.
        end: Unused; kept so all handlers share one signature.
        orig: List receiving the original-side span ``[start - 1, start]``.
        new: List receiving the target-side span
            ``[new_start - 1, new_start]``.
        new_start: Target-side index of the first inserted word. Defaults
            to ``start``.
    """
    anchor = start if new_start is None else new_start
    # Two distinct list objects: mutating one span must not affect the other.
    orig.append([start - 1, start])
    new.append([anchor - 1, anchor])


def handle_substitute(start, end, orig, new, new_start=None):
    """Record the spans of a substituted run of words.

    Args:
        start: Index of the first substituted word in the original.
        end: Index one past the last substituted word in the original.
        orig: List receiving the original-side span ``[start, end - 1]``.
        new: List receiving the equally long target-side span starting at
            ``new_start``.
        new_start: Index of the first replacement word in the target.
            Defaults to ``start``.
    """
    anchor = start if new_start is None else new_start
    orig.append([start, end - 1])
    new.append([anchor, anchor + (end - start) - 1])


def handle_last_operation(prev_op, start, end, orig, new, new_start=None):
    """Dispatch a finished run to the handler matching ``prev_op``.

    Operation characters other than ``"d"``, ``"i"`` and ``"s"`` are
    ignored and record nothing.
    """
    if prev_op == "d":
        handle_delete(start, end, orig, new, new_start)
    elif prev_op == "i":
        handle_insert(start, end, orig, new, new_start)
    elif prev_op == "s":
        handle_substitute(start, end, orig, new, new_start)


def adjust_last_span(operations, orig, new):
    """Shift the last span pair when the alignment ends with an edit.

    If the final operation is a deletion or an insertion, the last entries
    of ``orig`` and ``new`` are replaced (in place, in the lists) by spans
    moved one position towards the start; for a deletion the original-side
    span additionally collapses to a single index.

    NOTE(review): the source only says this matches the "edge case expected
    output"; presumably it keeps a following context word available for the
    mask - confirm against the consumer of these spans.

    Does nothing when ``operations`` is empty or no span has been recorded.
    """
    if not operations or not orig or not new:
        return

    last_op = operations[-1]
    if last_op == "d":
        new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
        orig[-1] = [orig[-1][0] - 1, orig[-1][0] - 1]
    elif last_op == "i":
        new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
        orig[-1] = [orig[-1][0] - 1, orig[-1][0]]


def get_spans(operations):
    """Convert an operations string into per-edit word spans.

    Consecutive identical edit operations form one span. Original-side
    spans are expressed in word positions of the original transcript and
    target-side spans in word positions of the target transcript, so
    earlier insertions or deletions do not shift later spans.

    Args:
        operations: Alignment string as returned by ``levenshtein_distance``.

    Returns:
        A tuple ``(orig, new)`` of equally long lists of ``[first, last]``
        spans, one pair per edit, in order of occurrence.
    """
    orig = []
    new = []
    orig_pos = 0  # words of the original consumed so far
    new_pos = 0  # words of the target consumed so far

    for op, run in groupby(operations):
        length = sum(1 for _ in run)
        if op != "=":
            handle_last_operation(
                op, orig_pos, orig_pos + length, orig, new, new_start=new_pos
            )
        if op != "i":  # insertions consume no original words
            orig_pos += length
        if op != "d":  # deletions consume no target words
            new_pos += length

    adjust_last_span(operations, orig, new)
    return orig, new


def get_edits(operations):
    """Return the edit-type name of every edit run in ``operations``.

    The result has one entry (``"insertion"``, ``"deletion"`` or
    ``"substitution"``) per run of identical edit operations, in order.
    """
    return [_EDIT_NAMES[op] for op, _ in groupby(operations) if op in _EDIT_NAMES]


def parse_edit(orig_transcript, trgt_transcript):
    """Align two transcripts word by word and describe the edits.

    Args:
        orig_transcript: Transcript of the original audio.
        trgt_transcript: Desired transcript after editing.

    Returns:
        A tuple ``(operations, orig_span, new_span)``: the alignment string
        and the span lists produced by ``get_spans``.
    """
    orig_words = extract_words(orig_transcript)
    trgt_words = extract_words(trgt_transcript)
    _, operations = levenshtein_distance(orig_words, trgt_words)
    orig_span, new_span = get_spans(operations)
    return operations, orig_span, new_span
1, + "execution_count": 24, "metadata": {}, "outputs": [], - "source": [ - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"7\"\n", - "os.environ[\"USER\"] = \"YOUR_USERNAME\" # TODO change this to your username" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], "source": [ "# import libs\n", "import torch\n", "import torchaudio\n", + "import os\n", "import numpy as np\n", "import random\n", - "from argparse import Namespace\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "os.environ[\"USER\"] = \"YOUR_USERNAME\" # TODO change this to your username\n", "\n", "from data.tokenizer import (\n", " AudioTokenizer,\n", " TextTokenizer,\n", ")\n", + "from inference_speech_editing_scale import get_mask_interval, inference_one_sample\n", + "from edit_utils import get_edits, parse_edit\n", "\n", + "from argparse import Namespace\n", "from models import voicecraft" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# install MFA models and dictionaries if you haven't done so already\n", - "!source ~/.bashrc && \\\n", - " conda activate voicecraft && \\\n", - " mfa model download dictionary english_us_arpa && \\\n", - " mfa model download acoustic english_us_arpa" + "# !source ~/.bashrc && \\\n", + "# conda activate voicecraft && \\\n", + "# mfa model download dictionary english_us_arpa && \\\n", + "# mfa model download acoustic english_us_arpa" ] }, { "cell_type": "code", - 
"execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hyperparameters for inference\n", "left_margin = 0.08\n", "right_margin = 0.08\n", + "sub_amount = 0.01\n", "codec_audio_sr = 16000\n", "codec_sr = 50\n", "top_k = 0\n", @@ -89,7 +75,6 @@ "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "\n", "# the new way of loading the model, with huggingface, recommended\n", - "from models import voicecraft\n", "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "phn2num = model.args.phn2num\n", "config = vars(model.args)\n", @@ -139,103 +124,74 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "original:\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "edited:\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "editTypes_set = set(['substitution', 'insertion', 'deletion'])\n", "# propose what do you want the target modified transcript to be\n", - "target_transcript = \"But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,\"\n", - "edit_type = \"substitution\"\n", - "assert edit_type in editTypes_set, f\"Invalid edit type {edit_type}. 
Must be one of {editTypes_set}.\"\n", - "\n", - "# if you want to do a second modification on top of the first one, write down the second modification (target_transcript2, type_of_modification2)\n", - "# make sure the two modification do not overlap, if they do, you need to combine them into one modification\n", + "target_transcript = \"But when I had approached so near, that the sense deceives, Lost not by farness any of its marks,\"\n", + "print(\"orig: \", orig_transcript)\n", + "print(\"trgt: \", target_transcript)\n", "\n", "# run the script to turn user input to the format that the model can take\n", - "from edit_utils import get_span\n", - "orig_span, new_span = get_span(orig_transcript, target_transcript, edit_type)\n", - "if orig_span[0] > orig_span[1]:\n", - " RuntimeError(f\"example {audio_fn} failed\")\n", - "if orig_span[0] == orig_span[1]:\n", - " orig_span_save = [orig_span[0]]\n", - "else:\n", - " orig_span_save = orig_span\n", - "if new_span[0] == new_span[1]:\n", - " new_span_save = [new_span[0]]\n", - "else:\n", - " new_span_save = new_span\n", - "\n", - "orig_span_save = \",\".join([str(item) for item in orig_span_save])\n", - "new_span_save = \",\".join([str(item) for item in new_span_save])\n", - "from inference_speech_editing_scale import get_mask_interval\n", + "operations, orig_span, new_span = parse_edit(orig_transcript, target_transcript)\n", + "if operations[-1] == 'i':\n", + " raise RuntimeError(\"The last operation should not be insertion. 
Please use text to speech instead\")\n", + "print(operations)\n", + "used_edits = get_edits(operations)\n", + "print(used_edits)\n", + "\n", + "def process_span(span):\n", + " if span[0] > span[1]:\n", + " raise RuntimeError(f\"example {audio_fn} failed\")\n", + " if span[0] == span[1]:\n", + " return [span[0]]\n", + " return span\n", + "\n", + "print(\"orig_span: \", orig_span)\n", + "print(\"new_span: \", new_span)\n", + "orig_span_save = [process_span(span) for span in orig_span]\n", + "new_span_save = [process_span(span) for span in new_span]\n", + "\n", + "orig_span_saves = [\",\".join([str(item) for item in span]) for span in orig_span_save]\n", + "new_span_saves = [\",\".join([str(item) for item in span]) for span in new_span_save]\n", + "\n", + "starting_intervals = []\n", + "ending_intervals = []\n", + "for i, orig_span_save in enumerate(orig_span_saves):\n", + " start, end = get_mask_interval(align_fn, orig_span_save, used_edits[i])\n", + " starting_intervals.append(start)\n", + " ending_intervals.append(end)\n", + "\n", + "print(\"intervals: \", starting_intervals, ending_intervals)\n", "\n", - "start, end = get_mask_interval(align_fn, orig_span_save, edit_type)\n", "info = torchaudio.info(audio_fn)\n", "audio_dur = info.num_frames / info.sample_rate\n", - "morphed_span = (max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur)) # in seconds\n", - "\n", - "# span in codec frames\n", - "mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]\n", - "mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n", "\n", + "def resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount):\n", + " while True:\n", + " morphed_span = [(max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur))\n", + " for start, end in zip(starting_intervals, ending_intervals)] # in seconds\n", + " mask_interval = [[round(span[0]*codec_sr), 
round(span[1]*codec_sr)] for span in morphed_span]\n", + " # Check for overlap\n", + " overlapping = any(a[1] >= b[0] for a, b in zip(mask_interval, mask_interval[1:]))\n", + " if not overlapping:\n", + " break\n", + " \n", + " # Reduce margins\n", + " left_margin -= sub_amount\n", + " right_margin -= sub_amount\n", + " \n", + " return mask_interval\n", "\n", "\n", + "# span in codec frames\n", + "mask_interval = resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount)\n", + "mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n", "# run the model to get the output\n", - "from inference_speech_editing_scale import inference_one_sample\n", - "\n", "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens}\n", "orig_audio, new_audio = inference_one_sample(model, Namespace(**config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, mask_interval, device, decode_config)\n", - " \n", + "\n", "# save segments for comparison\n", "orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()\n", "# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n", @@ -265,18 +221,11 @@ "# torchaudio.save(save_fn_orig, orig_audio, codec_audio_sr)\n", "\n", "# # if you get error importing T5 in transformers\n", - "# # try \n", + "# # try\n", "# # pip uninstall Pillow\n", "# # pip install Pillow\n", "# # you are likely to get warning looks like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -295,8 +244,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" - } + 
"version": "3.9.16" + }, + "nbformat": 4, + "nbformat_minor": 0 }, "nbformat": 4, "nbformat_minor": 2