prepare.py
import os
import sys
import csv

import numpy as np
import torch
import torchaudio
from torch.cuda.amp import autocast

# make the local src/ package importable regardless of the working directory
current_directory = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_directory)
from src.models import ASTModel
# Create a new class that inherits the original ASTModel class
class ASTModelVis(ASTModel):
    def get_att_map(self, block, x):
        qkv = block.attn.qkv
        num_heads = block.attn.num_heads
        scale = block.attn.scale
        B, N, C = x.shape
        qkv = qkv(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        return attn

    def forward_visualization(self, x):
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)
        B = x.shape[0]
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        # save the attention map of each of the 12 Transformer layers
        att_list = []
        for blk in self.v.blocks:
            cur_att = self.get_att_map(blk, x)
            att_list.append(cur_att)
            x = blk(x)
        return att_list
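
# A hedged usage sketch (not part of the original file): forward_visualization
# returns one attention tensor per Transformer block, each shaped
# (batch, num_heads, tokens, tokens), where token 0 is [CLS], token 1 is the
# distillation token, and the patch tokens start at index 2. Assuming a model
# `mdl` and a spectrogram batch `x` of shape (B, 1024, 128):
def _demo_attention(mdl, x):
    with torch.no_grad():
        att_list = mdl.forward_visualization(x)
    last_att = att_list[-1]          # attention of the final block: (B, heads, N, N)
    cls_att = last_att[:, :, 0, 2:]  # [CLS]-token attention over the patch tokens
    return cls_att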
def make_features(wav_name, mel_bins, target_length=1024):
    waveform, sr = torchaudio.load(wav_name)
    # assert sr == 16000, 'input audio sampling rate must be 16kHz'
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
    # zero-pad (or truncate) so the output always has target_length frames
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # normalize with the AudioSet dataset statistics (mean and 2 * std)
    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank
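
# A minimal sketch of calling make_features (assumption: a mono 16 kHz clip at
# ./audio.flac). The output is always (target_length, mel_bins) regardless of
# the clip's duration, because shorter clips are zero-padded and longer ones
# are truncated:
def _demo_make_features():
    feats = make_features('./audio.flac', mel_bins=128, target_length=1024)
    print(feats.shape)  # torch.Size([1024, 128])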
def load_label(label_csv):
    # parse the AudioSet class_labels_indices.csv (columns: index, mid, display_name)
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    labels = []
    ids = []  # Each label has a unique id such as "/m/068hy"
    for i1 in range(1, len(lines)):  # skip the header row
        id = lines[i1][1]
        label = lines[i1][2]
        ids.append(id)
        labels.append(label)
    return labels
def ASTpredict():
    # Assume each input spectrogram has 1024 time frames
    input_tdim = 1024
    checkpoint_path = './ast_master/pretrained_models/audio_mdl.pth'
    # now load the visualization model
    ast_mdl = ASTModelVis(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
    print(f'[*INFO] load checkpoint: {checkpoint_path}')
    checkpoint = torch.load(checkpoint_path, map_location='cuda')
    audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
    audio_model.load_state_dict(checkpoint)
    audio_model = audio_model.to(torch.device("cuda:0"))
    audio_model.eval()

    # Load the AudioSet label set
    label_csv = './ast_master/egs/audioset/data/class_labels_indices.csv'  # labels and indices for AudioSet data
    labels = load_label(label_csv)

    feats = make_features("./audio.flac", mel_bins=128)  # shape (1024, 128)
    feats_data = feats.expand(1, input_tdim, 128)  # add the batch dimension
    feats_data = feats_data.to(torch.device("cuda:0"))
    # do some masking of the input
    # feats_data[:, :512, :] = 0.

    # Make the prediction
    with torch.no_grad():
        with autocast():
            output = audio_model.forward(feats_data)
            output = torch.sigmoid(output)
    result_output = output.data.cpu().numpy()[0]
    sorted_indexes = np.argsort(result_output)[::-1]

    # Print the top audio tagging probabilities
    print('Prediction results:')
    for k in range(10):
        print('- {}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]], result_output[sorted_indexes[k]]))

    # return the top 10 labels and their probabilities
    top_labels_probs = {}
    top_labels = {}
    for k in range(10):
        label = np.array(labels)[sorted_indexes[k]]
        prob = result_output[sorted_indexes[k]]
        top_labels[k] = label
        top_labels_probs[k] = prob
    return top_labels, top_labels_probs
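
# A hedged entry point (not in the original file): running the script directly
# would execute the full tagging pipeline. It assumes the checkpoint, the label
# CSV, and ./audio.flac exist at the paths above, and that a CUDA device is
# available (the model is explicitly moved to cuda:0).
if __name__ == '__main__':
    top_labels, top_labels_probs = ASTpredict()
    for k in top_labels:
        print('{}: {:.4f}'.format(top_labels[k], top_labels_probs[k]))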