Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Add speaker recognition example #5327

Open
wants to merge 13 commits into
base: dev-static
Choose a base branch
from
Empty file.
849 changes: 849 additions & 0 deletions PaddleAudio/examples/speaker_recognition/augment.py

Large diffs are not rendered by default.

70 changes: 70 additions & 0 deletions PaddleAudio/examples/speaker_recognition/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class AngularMargin(nn.Layer):
def __init__(self, margin=0.0, scale=1.0):
super(AngularMargin, self).__init__()
self.margin = margin
self.scale = scale

def forward(self, outputs, targets):
outputs = outputs - self.margin * targets
return self.scale * outputs


class AdditiveAngularMargin(AngularMargin):
def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
super(AdditiveAngularMargin, self).__init__(margin, scale)
self.easy_margin = easy_margin

self.cos_m = math.cos(self.margin)
self.sin_m = math.sin(self.margin)
self.th = math.cos(math.pi - self.margin)
self.mm = math.sin(math.pi - self.margin) * self.margin

def forward(self, outputs, targets):
cosine = outputs.astype('float32')
sine = paddle.sqrt(1.0 - paddle.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
if self.easy_margin:
phi = paddle.where(cosine > 0, phi, cosine)
else:
phi = paddle.where(cosine > self.th, phi, cosine - self.mm)
outputs = (targets * phi) + ((1.0 - targets) * cosine)
return self.scale * outputs


class LogSoftmaxWrapper(nn.Layer):
def __init__(self, loss_fn):
super(LogSoftmaxWrapper, self).__init__()
self.loss_fn = loss_fn
self.criterion = paddle.nn.KLDivLoss(reduction="sum")

def forward(self, outputs, targets, length=None):
targets = F.one_hot(targets, outputs.shape[1])
try:
predictions = self.loss_fn(outputs, targets)
except TypeError:
predictions = self.loss_fn(outputs)

predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
29 changes: 29 additions & 0 deletions PaddleAudio/examples/speaker_recognition/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import numpy as np
from sklearn.metrics import roc_curve


def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
'''
Compute EER and return score threshold.
'''
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
fnr = 1 - tpr
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
return eer, eer_threshold
61 changes: 61 additions & 0 deletions PaddleAudio/examples/speaker_recognition/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SpeakerClassifier(nn.Layer):
def __init__(
self,
backbone,
num_class,
lin_blocks=0,
lin_neurons=192,
dropout=0.1,
):

super(SpeakerClassifier, self).__init__()
self.backbone = backbone
self.dropout = nn.Dropout(dropout)

input_size = self.backbone.emb_size
self.blocks = nn.LayerList()
for i in range(lin_blocks):
self.blocks.extend([
nn.BatchNorm1D(input_size),
nn.Linear(in_features=input_size, out_features=lin_neurons),
])
input_size = lin_neurons

self.weight = paddle.create_parameter(
shape=(input_size, num_class),
dtype='float32',
attr=paddle.ParamAttr(initializer=nn.initializer.XavierUniform()),
)

def forward(self, x, lengths=None):
# x.shape: (N, C, L)
x = self.backbone(x, lengths).squeeze(
-1) # (N, emb_size, 1) -> (N, emb_size)
x = self.dropout(x)

for fc in self.blocks:
x = fc(x)

# KP: W和x的向量归一化,输出为余弦相似度,供Additive Angular Margin计算loss
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用标准的注释
# NOTE(xxxxxgithuid or username): blabla

logits = F.linear(F.normalize(x), F.normalize(self.weight, axis=0))

return logits
223 changes: 223 additions & 0 deletions PaddleAudio/examples/speaker_recognition/signal_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle

# TODO: Complete type-hint and doc string.


def blackman_window(win_len, dtype=np.float32):
arcs = np.pi * np.arange(win_len) / float(win_len)
win = np.asarray(
[0.42 - 0.5 * np.cos(2 * arc) + 0.08 * np.cos(4 * arc) for arc in arcs],
dtype=dtype)
return paddle.to_tensor(win)


def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0)

assert amp_type in ["avg", "peak"]
assert scale in ["linear", "dB"]

if amp_type == "avg":
if lengths is None:
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)
else:
raise NotImplementedError

if scale == "linear":
return out
elif scale == "dB":
return paddle.clip(20 * paddle.log10(out), min=-80)
else:
raise NotImplementedError


def dB_to_amplitude(SNR):
return 10**(SNR / 20)


def convolve1d(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convolved1d和conv1d的区别是什么?

waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1,
):
if len(waveform.shape) != 3:
raise ValueError("Convolve1D expects a 3-dimensional tensor")

# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, list):
waveform = paddle.nn.functional.pad(
x=waveform,
pad=padding,
mode=pad_type,
data_format='NLC',
)

# Move time dimension last, which pad and fft and conv expect.
# (N, L, C) -> (N, C, L)
waveform = waveform.transpose([0, 2, 1])
kernel = kernel.transpose([0, 2, 1])

convolved = paddle.nn.functional.conv1d(
x=waveform,
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, list) else 0,
)

# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])


def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Check inputs
assert 0 < notch_freq <= 1
assert filter_width % 2 != 0
pad = filter_width // 2
inputs = paddle.arange(filter_width, dtype='float32') - pad

# Avoid frequencies that are too low
notch_freq += notch_width

# Define sinc function, avoiding division by zero
def sinc(x):
def _sinc(x):
return paddle.sin(x) / x

# The zero is at the middle index
res = paddle.concat(
[_sinc(x[:pad]),
paddle.ones([1]),
_sinc(x[pad + 1:])])
return res

# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
hlpf *= blackman_window(filter_width)
hlpf /= paddle.sum(hlpf)

# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
hhpf *= blackman_window(filter_width)
hhpf /= -paddle.sum(hhpf)
hhpf[pad] += 1

# Adding filters creates notch filter
return (hlpf + hhpf).reshape([1, -1, 1])


def reverberate(waveforms,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考SpeechBrain看这些基础function是否有注释

rir_waveform,
sample_rate,
impulse_duration=0.3,
rescale_amp="avg"):
orig_shape = waveforms.shape

if len(waveforms.shape) > 3 or len(rir_waveform.shape) > 3:
raise NotImplementedError

# if inputs are mono tensors we reshape to 1, samples
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0).unsqueeze(-1)
elif len(waveforms.shape) == 2:
waveforms = waveforms.unsqueeze(-1)

if len(rir_waveform.shape) == 1: # convolve1d expects a 3d tensor !
rir_waveform = rir_waveform.unsqueeze(0).unsqueeze(-1)
elif len(rir_waveform.shape) == 2:
rir_waveform = rir_waveform.unsqueeze(-1)

# Compute the average amplitude of the clean
orig_amplitude = compute_amplitude(waveforms, waveforms.shape[1],
rescale_amp)

# Compute index of the direct signal, so we can preserve alignment
impulse_index_start = rir_waveform.abs().argmax(axis=1).item()
impulse_index_end = min(
impulse_index_start + int(sample_rate * impulse_duration),
rir_waveform.shape[1])
rir_waveform = rir_waveform[:, impulse_index_start:impulse_index_end, :]
rir_waveform = rir_waveform / paddle.norm(rir_waveform, p=2)
rir_waveform = paddle.flip(rir_waveform, [1])

waveforms = convolve1d(
waveform=waveforms,
kernel=rir_waveform,
padding=[rir_waveform.shape[1] - 1, 0],
)

# Rescale to the peak amplitude of the clean waveform
waveforms = rescale(waveforms, waveforms.shape[1], orig_amplitude,
rescale_amp)

if len(orig_shape) == 1:
waveforms = waveforms.squeeze(0).squeeze(-1)
if len(orig_shape) == 2:
waveforms = waveforms.squeeze(-1)

return waveforms


def rescale(waveforms, lengths, target_lvl, amp_type="avg", scale="linear"):
assert amp_type in ["peak", "avg"]
assert scale in ["linear", "dB"]

batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)

waveforms = normalize(waveforms, lengths, amp_type)

if scale == "linear":
out = target_lvl * waveforms
elif scale == "dB":
out = dB_to_amplitude(target_lvl) * waveforms

else:
raise NotImplementedError("Invalid scale, choose between dB and linear")

if batch_added:
out = out.squeeze(0)

return out


def normalize(waveforms, lengths=None, amp_type="avg", eps=1e-14):
assert amp_type in ["avg", "peak"]

batch_added = False
if len(waveforms.shape) == 1:
batch_added = True
waveforms = waveforms.unsqueeze(0)

den = compute_amplitude(waveforms, lengths, amp_type) + eps
if batch_added:
waveforms = waveforms.squeeze(0)
return waveforms / den
Loading