-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix probability computation in WhisperNoSpeechDetection
when recomputing scores
#29248
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For whisper, num_beams > 1
paths seems a bit tricky. Would you mind adding a test to make sure we have the expected new results?
@ArthurZucker I added a slow test where I set the |
Hey @cifkao, thanks for the PR!
Also, could you point out why this scenario applies when language is set ? I understand the case for num_beams>1 but don't see the point for the other case! |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@ylacombe The bug manifests only when both conditions ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this fix @cifkao!
if input_ids.shape[1] == self.begin_index: | ||
if self.start_of_trans_offset > 1: | ||
with torch.no_grad(): | ||
logits = self.model(**self.inputs).logits | ||
|
||
no_speech_index = self.begin_index - self.start_of_trans_offset | ||
no_speech_scores = logits[:, no_speech_index] | ||
is_scores_logprobs = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch!
@@ -2615,6 +2615,59 @@ def test_whisper_longform_multi_batch_hard_prev_cond(self): | |||
for i in range(num_samples): | |||
assert decoded_all[i] == EXPECTED_TEXT[i] | |||
|
|||
@slow | |||
def test_whisper_longform_no_speech_detection(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this slow test! Just to confirm, before the fix all the transcriptions are empty due to the no-speech probabilities exceeding 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly.
Ready for final review from @ArthurZucker |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤯 awesome catch!
Thank you for taking the time to add a test 🔥 |
Thanks for the contribution @cifkao! |
…uting scores (#29248) * Fix is_scores_logprobs in WhisperNoSpeechDetection * Add test_whisper_longform_no_speech_detection * Fix typo
…uting scores (#29248) * Fix is_scores_logprobs in WhisperNoSpeechDetection * Add test_whisper_longform_no_speech_detection * Fix typo
What does this PR do?
Fix #29313.
Before submitting
Who can review?
@patrickvonplaten @sanchit-gandhi @ylacombe