Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-edit capabilities to Speech Editing #94

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 128 additions & 49 deletions edit_utils.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,128 @@
def get_span(orig, new, editType):
orig_list = orig.split(" ")
new_list = new.split(" ")

flag = False # this indicate whether the actual edit follow the specified editType
if editType == "deletion":
assert len(orig_list) > len(new_list), f"the edit type is deletion, but new is not shorter than original:\n new: {new}\n orig: {orig}"
diff = len(orig_list) - len(new_list)
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n: # assume the index of the first different word is the starting index of the orig_span

orig_span = [i, i + diff - 1] # assume that the indices are starting and ending index of the deleted part
new_span = [i-1, i] # but for the new span, the starting and ending index is the two words that surround the deleted part
flag = True
break


elif editType == "insertion":
assert len(orig_list) < len(new_list), f"the edit type is insertion, but the new is not longer than the original:\n new: {new}\n orig: {orig}"
diff = len(new_list) - len(orig_list)
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n: # insertion is just the opposite of deletion
new_span = [i, i + diff - 1] # NOTE if only inserted one word, s and e will be the same
orig_span = [i-1, i]
flag = True
break

elif editType == "substitution":
new_span = []
orig_span = []
for i, (o, n) in enumerate(zip(orig_list, new_list)):
if o != n:
new_span = [i]
orig_span = [i]
break
assert len(new_span) == 1 and len(orig_span) == 1, f"new_span: {new_span}, orig_span: {orig_span}"
for j, (o, n) in enumerate(zip(orig_list[::-1], new_list[::-1])):
if o != n:
new_span.append(len(new_list) - j -1)
orig_span.append(len(orig_list) - j - 1)
flag = True
break
else:
raise RuntimeError(f"editType unknown: {editType}")

if not flag:
raise RuntimeError(f"wrong editing with the specified edit type:\n original: {orig}\n new: {new}\n, editType: {editType}")

return orig_span, new_span
import re


def levenshtein_distance(word1, word2):
len1, len2 = len(word1), len(word2)
# Initialize a matrix to store the edit distances, operations, and positions
dp = [[(0, "", []) for _ in range(len2 + 1)] for _ in range(len1 + 1)]

# Initialize the first row and column
for i in range(len1 + 1):
dp[i][0] = (i, "d" * i)
for j in range(len2 + 1):
dp[0][j] = (j, "i" * j)

# Fill in the rest of the matrix
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
cost = 0 if word1[i - 1] == word2[j - 1] else 1
# Minimum of deletion, insertion, or substitution
deletion = dp[i - 1][j][0] + 1
insertion = dp[i][j - 1][0] + 1
substitution = dp[i - 1][j - 1][0] + cost
min_dist = min(deletion, insertion, substitution)

# which operation led to the minimum distance
if min_dist == deletion:
operation = dp[i - 1][j][1] + "d"
elif min_dist == insertion:
operation = dp[i][j - 1][1] + "i"
else:
operation = dp[i - 1][j - 1][1] + ("s" if cost else "=")

dp[i][j] = (min_dist, operation)

# min edit distance, list of operations, positions of operations
return dp[len1][len2][0], dp[len1][len2][1]


def extract_words(sentence):
words = re.findall(r"\b[\w']+\b", sentence)
return words


# edge cases for spans of deletion, insertion, substitution
def handle_delete(start, end, orig, new):
orig.append([start, end - 1])
new.append([start - 1, start])


def handle_insert(start, end, orig, new):
temp_new = [start - 1, start]
orig.append(temp_new)
new.append(orig[-1])
orig[-1], new[-1] = new[-1], temp_new


def handle_substitute(start, end, orig, new):
orig.append([start, end - 1])
new.append([start, end - 1])


# editing the last index of the sentence is another edge case
def handle_last_operation(prev_op, start, end, orig, new):
if prev_op == "d":
handle_delete(start, end, orig, new)
elif prev_op == "i":
handle_insert(start, end, orig, new)
elif prev_op == "s":
handle_substitute(start, end, orig, new)


# adjust spans according to edge case expected output
def adjust_last_span(operations, orig, new):
if operations[-1] == "d":
new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
orig[-1] = [orig[-1][0] - 1, orig[-1][0] - 1]
elif operations[-1] == "i":
new[-1] = [new[-1][0] - 1, new[-1][1] - 1]
orig[-1] = [orig[-1][0] - 1, orig[-1][0]]


def get_spans(operations):
orig = []
new = []
prev_op = None
start = 0
end = 0
for i, op in enumerate(operations):
# prevent span duplication of sequential edits of the same type
if op != "=":
if op != prev_op:
if prev_op:
handle_last_operation(prev_op, start, end, orig, new)
prev_op = op
start = i
end = i + 1
else:
if prev_op:
handle_last_operation(prev_op, start, end, orig, new)
prev_op = None
start = end
# edge case of last operation
if prev_op:
handle_last_operation(prev_op, start, end, orig, new)
adjust_last_span(operations, orig, new)
return orig, new


def get_edits(operations):
used_edits = []
prev_op = ""
for op in operations:
if op == "i" and prev_op != "i":
used_edits.append("insertion")
elif op == "d" and prev_op != "d":
used_edits.append("deletion")
elif op == "s" and prev_op != "s":
used_edits.append("substitution")
prev_op = op
return used_edits


def parse_edit(orig_transcript, trgt_transcript):
word1 = extract_words(orig_transcript)
word2 = extract_words(trgt_transcript)
distance, operations = levenshtein_distance(word1, word2)
orig_span, new_span = get_spans(operations)
return operations, orig_span, new_span
197 changes: 74 additions & 123 deletions inference_speech_editing.ipynb

Large diffs are not rendered by default.