diff --git a/edit_utils.py b/edit_utils.py
index a9683f8..990de64 100644
--- a/edit_utils.py
+++ b/edit_utils.py
@@ -1,49 +1,128 @@
-def get_span(orig, new, editType):
- orig_list = orig.split(" ")
- new_list = new.split(" ")
-
- flag = False # this indicate whether the actual edit follow the specified editType
- if editType == "deletion":
- assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}"
- diff = len(orig_list) - len(new_list)
- for i, (o, n) in enumerate(zip(orig_list, new_list)):
- if o != n: # assume the index of the first different word is the starting index of the orig_span
-
- orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part
- new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part
- flag = True
- break
-
-
- elif editType == "insertion":
- assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}"
- diff = len(new_list) - len(orig_list)
- for i, (o, n) in enumerate(zip(orig_list, new_list)):
- if o != n: # insertion is just the opposite of deletion
- new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same
- orig_span = [i-1, i]
- flag = True
- break
-
- elif editType == "substitution":
- new_span = []
- orig_span = []
- for i, (o, n) in enumerate(zip(orig_list, new_list)):
- if o != n:
- new_span = [i]
- orig_span = [i]
- break
- assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}"
- for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])):
- if o != n:
- new_span.append(len(new_list) - j -1)
- orig_span.append(len(orig_list) - j - 1)
- flag = True
- break
- else:
- raise RuntimeError(f"editType unknown: {editType}")
-
- if not flag:
- raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}")
-
- return orig_span, new_span
\ No newline at end of file
+import re
+
+
+def levenshtein_distance(word1, word2):
+    """Compute the word-level edit distance and the edit script.
+
+    Returns (distance, ops): ops holds one code per alignment step, "d" delete,
+    "i" insert, "s" substitute, "=" keep, for turning word1 into word2.
+    """
+    rows, cols = len(word1) + 1, len(word2) + 1
+    # table[r][c] = (edit distance, op string) for word1[:r] -> word2[:c]
+    table = [[(0, "")] * cols for _ in range(rows)]
+
+    # Borders: delete all of word1[:r], or insert all of word2[:c].
+    for r in range(rows):
+        table[r][0] = (r, "d" * r)
+    for c in range(cols):
+        table[0][c] = (c, "i" * c)
+
+    # Fill the interior one source word (row) at a time.
+    for r, src in enumerate(word1, start=1):
+        for c, dst in enumerate(word2, start=1):
+            up_dist, up_ops = table[r - 1][c]
+            left_dist, left_ops = table[r][c - 1]
+            diag_dist, diag_ops = table[r - 1][c - 1]
+            cost = 0 if src == dst else 1
+            # Candidates in tie-break order: deletion, insertion, then
+            # substitution/match; the first one with the lowest cost wins.
+            candidates = (
+                (up_dist + 1, up_ops + "d"),
+                (left_dist + 1, left_ops + "i"),
+                (diag_dist + cost, diag_ops + ("s" if cost else "=")),
+            )
+            table[r][c] = min(candidates, key=lambda cand: cand[0])
+
+    return table[rows - 1][cols - 1]
+
+
+def extract_words(sentence):
+    """Return runs of word characters/apostrophes delimited by word boundaries."""
+    return re.findall(r"\b[\w']+\b", sentence)
+
+
+def handle_delete(start, end, orig, new):
+    """Record a deletion run: [start, end - 1] in orig, [start - 1, start] in new."""
+    orig.append([start, end - 1])
+    new.append([start - 1, start])
+
+
+def handle_insert(start, end, orig, new):
+    """Record an insertion run as [start - 1, start] in both lists (end unused)."""
+    orig.append([start - 1, start])
+    new.append([start - 1, start])  # separate list so the two entries never alias
+
+
+def handle_substitute(start, end, orig, new):
+    """Record a substitution run: the same [start, end - 1] span in both lists."""
+    orig.append([start, end - 1])
+    new.append([start, end - 1])
+
+
+def handle_last_operation(prev_op, start, end, orig, new):
+    """Close the run of prev_op edits over [start, end) by recording its spans.
+
+    Dispatches to the handler for "d", "i" or "s"; any other code is ignored.
+    """
+    recorders = {"d": handle_delete, "i": handle_insert, "s": handle_substitute}
+    if prev_op in recorders:
+        recorders[prev_op](start, end, orig, new)
+
+
+def adjust_last_span(operations, orig, new):
+    """Shift the last span pair back by one if operations ends in "d" or "i"."""
+    last_op = operations[-1] if operations else None  # empty script: nothing to adjust
+    if last_op in ("d", "i"):
+        o_start = orig[-1][0] - 1
+        new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
+        # A trailing deletion collapses orig to one index; an insertion keeps two.
+        orig[-1] = [o_start, o_start if last_op == "d" else o_start + 1]
+
+
+def get_spans(operations):
+    """Convert an edit script into parallel lists of original/new span pairs.
+
+    Consecutive identical non-"=" codes form one run, recorded once through
+    handle_last_operation with start/end taken from positions in operations.
+    NOTE(review): those positions are edit-script indices; after an earlier
+    "i" or "d" they may differ from word indices - confirm with the caller.
+    """
+    orig, new = [], []
+    if not operations:
+        return orig, new  # nothing to align; avoids indexing an empty script
+    run_op, run_start = None, 0
+    for pos, op in enumerate(operations):
+        if op == run_op:
+            continue  # still inside the current run
+        # The code changed: close the finished edit run (if any) at pos.
+        if run_op:
+            handle_last_operation(run_op, run_start, pos, orig, new)
+        run_op = None if op == "=" else op
+        run_start = pos
+    if run_op:
+        # The script ended inside an edit run.
+        handle_last_operation(run_op, run_start, len(operations), orig, new)
+    adjust_last_span(operations, orig, new)
+    return orig, new
+
+
+def get_edits(operations):
+    """List the edit type of each run of consecutive identical edit codes.
+
+    "i", "d" and "s" map to "insertion", "deletion" and "substitution"; "="
+    contributes nothing but still ends the run before it.
+    """
+    labels = {"i": "insertion", "d": "deletion", "s": "substitution"}
+    return [
+        labels[op]
+        for prev, op in zip([""] + list(operations), operations)
+        if op in labels and op != prev
+    ]
+
+
+def parse_edit(orig_transcript, trgt_transcript):
+    """Align two transcripts word by word; return (operations, orig_span, new_span)."""
+    source_words = extract_words(orig_transcript)
+    target_words = extract_words(trgt_transcript)
+    _, operations = levenshtein_distance(source_words, target_words)
+    return (operations, *get_spans(operations))
diff --git a/inference_speech_editing.ipynb b/inference_speech_editing.ipynb
index a0b5cd5..852d4f0 100644
--- a/inference_speech_editing.ipynb
+++ b/inference_speech_editing.ipynb
@@ -2,68 +2,54 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 24,
"metadata": {},
"outputs": [],
- "source": [
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"7\"\n",
- "os.environ[\"USER\"] = \"YOUR_USERNAME\" # TODO change this to your username"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/pyp/miniconda3/envs/voicecraft/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
"source": [
"# import libs\n",
"import torch\n",
"import torchaudio\n",
+ "import os\n",
"import numpy as np\n",
"import random\n",
- "from argparse import Namespace\n",
+ "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
+ "os.environ[\"USER\"] = \"YOUR_USERNAME\" # TODO change this to your username\n",
"\n",
"from data.tokenizer import (\n",
" AudioTokenizer,\n",
" TextTokenizer,\n",
")\n",
+ "from inference_speech_editing_scale import get_mask_interval, inference_one_sample\n",
+ "from edit_utils import get_edits, parse_edit\n",
"\n",
+ "from argparse import Namespace\n",
"from models import voicecraft"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# install MFA models and dictionaries if you haven't done so already\n",
- "!source ~/.bashrc && \\\n",
- " conda activate voicecraft && \\\n",
- " mfa model download dictionary english_us_arpa && \\\n",
- " mfa model download acoustic english_us_arpa"
+ "# !source ~/.bashrc && \\\n",
+ "# conda activate voicecraft && \\\n",
+ "# mfa model download dictionary english_us_arpa && \\\n",
+ "# mfa model download acoustic english_us_arpa"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters for inference\n",
"left_margin = 0.08\n",
"right_margin = 0.08\n",
+ "sub_amount = 0.01\n",
"codec_audio_sr = 16000\n",
"codec_sr = 50\n",
"top_k = 0\n",
@@ -89,7 +75,6 @@
"voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n",
"\n",
"# the new way of loading the model, with huggingface, recommended\n",
- "from models import voicecraft\n",
"model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n",
"phn2num = model.args.phn2num\n",
"config = vars(model.args)\n",
@@ -139,103 +124,74 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "original:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "edited:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "outputs": [],
"source": [
- "editTypes_set = set(['substitution', 'insertion', 'deletion'])\n",
"# propose what do you want the target modified transcript to be\n",
- "target_transcript = \"But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,\"\n",
- "edit_type = \"substitution\"\n",
- "assert edit_type in editTypes_set, f\"Invalid edit type {edit_type}. Must be one of {editTypes_set}.\"\n",
- "\n",
- "# if you want to do a second modification on top of the first one, write down the second modification (target_transcript2, type_of_modification2)\n",
- "# make sure the two modification do not overlap, if they do, you need to combine them into one modification\n",
+ "target_transcript = \"But when I had approached so near, that the sense deceives, Lost not by farness any of its marks,\"\n",
+ "print(\"orig: \", orig_transcript)\n",
+ "print(\"trgt: \", target_transcript)\n",
"\n",
"# run the script to turn user input to the format that the model can take\n",
- "from edit_utils import get_span\n",
- "orig_span, new_span = get_span(orig_transcript, target_transcript, edit_type)\n",
- "if orig_span[0] > orig_span[1]:\n",
- " RuntimeError(f\"example {audio_fn} failed\")\n",
- "if orig_span[0] == orig_span[1]:\n",
- " orig_span_save = [orig_span[0]]\n",
- "else:\n",
- " orig_span_save = orig_span\n",
- "if new_span[0] == new_span[1]:\n",
- " new_span_save = [new_span[0]]\n",
- "else:\n",
- " new_span_save = new_span\n",
- "\n",
- "orig_span_save = \",\".join([str(item) for item in orig_span_save])\n",
- "new_span_save = \",\".join([str(item) for item in new_span_save])\n",
- "from inference_speech_editing_scale import get_mask_interval\n",
+ "operations, orig_span, new_span = parse_edit(orig_transcript, target_transcript)\n",
+ "if operations[-1] == 'i':\n",
+ " raise RuntimeError(\"The last operation should not be insertion. Please use text to speech instead\")\n",
+ "print(operations)\n",
+ "used_edits = get_edits(operations)\n",
+ "print(used_edits)\n",
+ "\n",
+ "def process_span(span):\n",
+ " if span[0] > span[1]:\n",
+ " raise RuntimeError(f\"example {audio_fn} failed\")\n",
+ " if span[0] == span[1]:\n",
+ " return [span[0]]\n",
+ " return span\n",
+ "\n",
+ "print(\"orig_span: \", orig_span)\n",
+ "print(\"new_span: \", new_span)\n",
+ "orig_span_save = [process_span(span) for span in orig_span]\n",
+ "new_span_save = [process_span(span) for span in new_span]\n",
+ "\n",
+ "orig_span_saves = [\",\".join([str(item) for item in span]) for span in orig_span_save]\n",
+ "new_span_saves = [\",\".join([str(item) for item in span]) for span in new_span_save]\n",
+ "\n",
+ "starting_intervals = []\n",
+ "ending_intervals = []\n",
+ "for i, orig_span_save in enumerate(orig_span_saves):\n",
+ " start, end = get_mask_interval(align_fn, orig_span_save, used_edits[i])\n",
+ " starting_intervals.append(start)\n",
+ " ending_intervals.append(end)\n",
+ "\n",
+ "print(\"intervals: \", starting_intervals, ending_intervals)\n",
"\n",
- "start, end = get_mask_interval(align_fn, orig_span_save, edit_type)\n",
"info = torchaudio.info(audio_fn)\n",
"audio_dur = info.num_frames / info.sample_rate\n",
- "morphed_span = (max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur)) # in seconds\n",
- "\n",
- "# span in codec frames\n",
- "mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]\n",
- "mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now\n",
"\n",
+ "def resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount):\n",
+ " while True:\n",
+ " morphed_span = [(max(start - left_margin, 1/codec_sr), min(end + right_margin, audio_dur))\n",
+ " for start, end in zip(starting_intervals, ending_intervals)] # in seconds\n",
+ " mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]\n",
+ " # Check for overlap\n",
+ " overlapping = any(a[1] >= b[0] for a, b in zip(mask_interval, mask_interval[1:]))\n",
+ " if not overlapping:\n",
+ " break\n",
+ " \n",
+ " # Reduce margins\n",
+ " left_margin -= sub_amount\n",
+ " right_margin -= sub_amount\n",
+ " \n",
+ " return mask_interval\n",
"\n",
"\n",
+ "# span in codec frames\n",
+ "mask_interval = resolve_overlap(starting_intervals, ending_intervals, audio_dur, codec_sr, left_margin, right_margin, sub_amount)\n",
+ "mask_interval = torch.LongTensor(mask_interval) # [M,2], one row per edited span\n",
"# run the model to get the output\n",
- "from inference_speech_editing_scale import inference_one_sample\n",
- "\n",
"decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens}\n",
"orig_audio, new_audio = inference_one_sample(model, Namespace(**config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, mask_interval, device, decode_config)\n",
- " \n",
+ "\n",
"# save segments for comparison\n",
"orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()\n",
"# logging.info(f\"length of the resynthesize orig audio: {orig_audio.shape}\")\n",
@@ -265,18 +221,11 @@
"# torchaudio.save(save_fn_orig, orig_audio, codec_audio_sr)\n",
"\n",
"# # if you get error importing T5 in transformers\n",
- "# # try \n",
+ "# # try\n",
"# # pip uninstall Pillow\n",
"# # pip install Pillow\n",
"# # you are likely to get warning looks like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
@@ -295,8 +244,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.18"
- }
+ "version": "3.9.16"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
},
"nbformat": 4,
"nbformat_minor": 2