-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP]Add speaker recognition example #5327
Open
KPatr1ck
wants to merge
13
commits into
PaddlePaddle:dev-static
Choose a base branch
from
KPatr1ck:speaker_recognition
base: dev-static
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
136b8e2
Add speaker recognition example
KPatr1ck e0fd481
-
KPatr1ck 126a157
-
KPatr1ck bfd2c4f
Add speaker verification and compute EER
KPatr1ck ac25ffb
Add audio augmentation
KPatr1ck cb1df8f
Update usage of new APIs
KPatr1ck 0d0a582
-
KPatr1ck 914088a
-
KPatr1ck 53411d8
Merge remote-tracking branch 'update_stream/develop' into speaker_rec…
KPatr1ck 46f063a
-
KPatr1ck e87edc8
-
KPatr1ck 7d0ead8
Add score norm
KPatr1ck e5298de
-
KPatr1ck File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
|
||
class AngularMargin(nn.Layer): | ||
def __init__(self, margin=0.0, scale=1.0): | ||
super(AngularMargin, self).__init__() | ||
self.margin = margin | ||
self.scale = scale | ||
|
||
def forward(self, outputs, targets): | ||
outputs = outputs - self.margin * targets | ||
return self.scale * outputs | ||
|
||
|
||
class AdditiveAngularMargin(AngularMargin): | ||
def __init__(self, margin=0.0, scale=1.0, easy_margin=False): | ||
super(AdditiveAngularMargin, self).__init__(margin, scale) | ||
self.easy_margin = easy_margin | ||
|
||
self.cos_m = math.cos(self.margin) | ||
self.sin_m = math.sin(self.margin) | ||
self.th = math.cos(math.pi - self.margin) | ||
self.mm = math.sin(math.pi - self.margin) * self.margin | ||
|
||
def forward(self, outputs, targets): | ||
cosine = outputs.astype('float32') | ||
sine = paddle.sqrt(1.0 - paddle.pow(cosine, 2)) | ||
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m) | ||
if self.easy_margin: | ||
phi = paddle.where(cosine > 0, phi, cosine) | ||
else: | ||
phi = paddle.where(cosine > self.th, phi, cosine - self.mm) | ||
outputs = (targets * phi) + ((1.0 - targets) * cosine) | ||
return self.scale * outputs | ||
|
||
|
||
class LogSoftmaxWrapper(nn.Layer): | ||
def __init__(self, loss_fn): | ||
super(LogSoftmaxWrapper, self).__init__() | ||
self.loss_fn = loss_fn | ||
self.criterion = paddle.nn.KLDivLoss(reduction="sum") | ||
|
||
def forward(self, outputs, targets, length=None): | ||
targets = F.one_hot(targets, outputs.shape[1]) | ||
try: | ||
predictions = self.loss_fn(outputs, targets) | ||
except TypeError: | ||
predictions = self.loss_fn(outputs) | ||
|
||
predictions = F.log_softmax(predictions, axis=1) | ||
loss = self.criterion(predictions, targets) / targets.sum() | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List | ||
|
||
import numpy as np | ||
from sklearn.metrics import roc_curve | ||
|
||
|
||
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]: | ||
''' | ||
Compute EER and return score threshold. | ||
''' | ||
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores) | ||
fnr = 1 - tpr | ||
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] | ||
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] | ||
return eer, eer_threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
|
||
|
||
class SpeakerClassifier(nn.Layer): | ||
def __init__( | ||
self, | ||
backbone, | ||
num_class, | ||
lin_blocks=0, | ||
lin_neurons=192, | ||
dropout=0.1, | ||
): | ||
|
||
super(SpeakerClassifier, self).__init__() | ||
self.backbone = backbone | ||
self.dropout = nn.Dropout(dropout) | ||
|
||
input_size = self.backbone.emb_size | ||
self.blocks = nn.LayerList() | ||
for i in range(lin_blocks): | ||
self.blocks.extend([ | ||
nn.BatchNorm1D(input_size), | ||
nn.Linear(in_features=input_size, out_features=lin_neurons), | ||
]) | ||
input_size = lin_neurons | ||
|
||
self.weight = paddle.create_parameter( | ||
shape=(input_size, num_class), | ||
dtype='float32', | ||
attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()), | ||
) | ||
|
||
def forward(self, x, lengths=None): | ||
# x.shape: (N, C, L) | ||
x = self.backbone(x, lengths).squeeze( | ||
-1) # (N, emb_size, 1) -> (N, emb_size) | ||
x = self.dropout(x) | ||
|
||
for fc in self.blocks: | ||
x = fc(x) | ||
|
||
# KP: W和x的向量归一化,输出为余弦相似度,供Additive Angular Margin计算loss | ||
logits = F.linear(F.normalize(x), F.normalize(self.weight, axis=0)) | ||
|
||
return logits |
223 changes: 223 additions & 0 deletions
223
PaddleAudio/examples/speaker_recognition/signal_processing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
|
||
import numpy as np | ||
import paddle | ||
|
||
# TODO: Complete type-hint and doc string. | ||
|
||
|
||
def blackman_window(win_len, dtype=np.float32): | ||
arcs = np.pi * np.arange(win_len) / float(win_len) | ||
win = np.asarray( | ||
[0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs], | ||
dtype=dtype) | ||
return paddle.to_tensor(win) | ||
|
||
|
||
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"): | ||
if len(waveforms.shape) == 1: | ||
waveforms = waveforms.unsqueeze(0) | ||
|
||
assert amp_type in ["avg", "peak"] | ||
assert scale in ["linear", "dB"] | ||
|
||
if amp_type == "avg": | ||
if lengths is None: | ||
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True) | ||
else: | ||
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True) | ||
out = wav_sum / lengths | ||
elif amp_type == "peak": | ||
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True) | ||
else: | ||
raise NotImplementedError | ||
|
||
if scale == "linear": | ||
return out | ||
elif scale == "dB": | ||
return paddle.clip(20 * paddle.log10(out), min=-80) | ||
else: | ||
raise NotImplementedError | ||
|
||
|
||
def dB_to_amplitude(SNR): | ||
return 10**(SNR / 20) | ||
|
||
|
||
def convolve1d( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. convolved1d和conv1d的区别是什么? |
||
waveform, | ||
kernel, | ||
padding=0, | ||
pad_type="constant", | ||
stride=1, | ||
groups=1, | ||
): | ||
if len(waveform.shape) != 3: | ||
raise ValueError("Convolve1D expects a 3-dimensional tensor") | ||
|
||
# Padding can be a tuple (left_pad, right_pad) or an int | ||
if isinstance(padding, list): | ||
waveform = paddle.nn.functional.pad( | ||
x=waveform, | ||
pad=padding, | ||
mode=pad_type, | ||
data_format='NLC', | ||
) | ||
|
||
# Move time dimension last, which pad and fft and conv expect. | ||
# (N, L, C) -> (N, C, L) | ||
waveform = waveform.transpose([0, 2, 1]) | ||
kernel = kernel.transpose([0, 2, 1]) | ||
|
||
convolved = paddle.nn.functional.conv1d( | ||
x=waveform, | ||
weight=kernel, | ||
stride=stride, | ||
groups=groups, | ||
padding=padding if not isinstance(padding, list) else 0, | ||
) | ||
|
||
# Return time dimension to the second dimension. | ||
return convolved.transpose([0, 2, 1]) | ||
|
||
|
||
def notch_filter(notch_freq, filter_width=101, notch_width=0.05): | ||
# Check inputs | ||
assert 0 < notch_freq <= 1 | ||
assert filter_width % 2 != 0 | ||
pad = filter_width // 2 | ||
inputs = paddle.arange(filter_width, dtype='float32') - pad | ||
|
||
# Avoid frequencies that are too low | ||
notch_freq += notch_width | ||
|
||
# Define sinc function, avoiding division by zero | ||
def sinc(x): | ||
def _sinc(x): | ||
return paddle.sin(x) / x | ||
|
||
# The zero is at the middle index | ||
res = paddle.concat( | ||
[_sinc(x[:pad]), | ||
paddle.ones([1]), | ||
_sinc(x[pad + 1:])]) | ||
return res | ||
|
||
# Compute a low-pass filter with cutoff frequency notch_freq. | ||
hlpf = sinc(3 * (notch_freq - notch_width) * inputs) | ||
hlpf *= blackman_window(filter_width) | ||
hlpf /= paddle.sum(hlpf) | ||
|
||
# Compute a high-pass filter with cutoff frequency notch_freq. | ||
hhpf = sinc(3 * (notch_freq + notch_width) * inputs) | ||
hhpf *= blackman_window(filter_width) | ||
hhpf /= -paddle.sum(hhpf) | ||
hhpf[pad] += 1 | ||
|
||
# Adding filters creates notch filter | ||
return (hlpf + hhpf).reshape([1, -1, 1]) | ||
|
||
|
||
def reverberate(waveforms, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 参考SpeechBrain看这些基础function是否有注释 |
||
rir_waveform, | ||
sample_rate, | ||
impulse_duration=0.3, | ||
rescale_amp="avg"): | ||
orig_shape = waveforms.shape | ||
|
||
if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3: | ||
raise NotImplementedError | ||
|
||
# if inputs are mono tensors we reshape to 1, samples | ||
if len(waveforms.shape) == 1: | ||
waveforms = waveforms.unsqueeze(0).unsqueeze(-1) | ||
elif len(waveforms.shape) == 2: | ||
waveforms = waveforms.unsqueeze(-1) | ||
|
||
if len(rir_waveform.shape) == 1: # convolve1d expects a 3d tensor ! | ||
rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1) | ||
elif len(rir_waveform.shape) == 2: | ||
rir_waveform = rir_waveform.unsqueeze(-1) | ||
|
||
# Compute the average amplitude of the clean | ||
orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1], | ||
rescale_amp) | ||
|
||
# Compute index of the direct signal, so we can preserve alignment | ||
impulse_index_start = rir_waveform.abs().argmax(axis=1).item() | ||
impulse_index_end = min( | ||
impulse_index_start + int(sample_rate * impulse_duration), | ||
rir_waveform.shape[1]) | ||
rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :] | ||
rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2) | ||
rir_waveform = paddle.flip(rir_waveform, [1]) | ||
|
||
waveforms = convolve1d( | ||
waveform=waveforms, | ||
kernel=rir_waveform, | ||
padding=[rir_waveform.shape[1] - 1, 0], | ||
) | ||
|
||
# Rescale to the peak amplitude of the clean waveform | ||
waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude, | ||
rescale_amp) | ||
|
||
if len(orig_shape) == 1: | ||
waveforms = waveforms.squeeze(0).squeeze(-1) | ||
if len(orig_shape) == 2: | ||
waveforms = waveforms.squeeze(-1) | ||
|
||
return waveforms | ||
|
||
|
||
def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"): | ||
assert amp_type in ["peak", "avg"] | ||
assert scale in ["linear", "dB"] | ||
|
||
batch_added = False | ||
if len(waveforms.shape) == 1: | ||
batch_added = True | ||
waveforms = waveforms.unsqueeze(0) | ||
|
||
waveforms = normalize(waveforms, lengths, amp_type) | ||
|
||
if scale == "linear": | ||
out = target_lvl * waveforms | ||
elif scale == "dB": | ||
out = dB_to_amplitude(target_lvl) * waveforms | ||
|
||
else: | ||
raise NotImplementedError("Invalid scale, choose between dB and linear") | ||
|
||
if batch_added: | ||
out = out.squeeze(0) | ||
|
||
return out | ||
|
||
|
||
def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14): | ||
assert amp_type in ["avg", "peak"] | ||
|
||
batch_added = False | ||
if len(waveforms.shape) == 1: | ||
batch_added = True | ||
waveforms = waveforms.unsqueeze(0) | ||
|
||
den = compute_amplitude(waveforms, lengths, amp_type) + eps | ||
if batch_added: | ||
waveforms = waveforms.squeeze(0) | ||
return waveforms / den |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用标准的注释
# NOTE(xxxxxgithuid or username): blabla