forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.rs
362 lines (333 loc) · 12.2 KB
/
model.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_mel_bins: usize, // n_mels
pub max_source_positions: usize, // n_audio_ctx
pub d_model: usize, // n_audio_state
pub encoder_attention_heads: usize, // n_audio_head
pub encoder_layers: usize, // n_audio_layer
pub vocab_size: usize, // n_vocab
pub max_target_positions: usize, // n_text_ctx
// pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer
}
impl Config {
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
vocab_size: 51864,
max_source_positions: 1500,
d_model: 384,
encoder_attention_heads: 6,
encoder_layers: 4,
max_target_positions: 448,
// n_text_state: 384,
decoder_attention_heads: 6,
decoder_layers: 4,
}
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
let bias = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
Ok(Conv1d::new(weight, None, config))
}
struct Dropout {
pr: f64,
}
impl Dropout {
fn new(pr: f64) -> Self {
Self { pr }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
}
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, 1e-5))
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
query: Linear,
key: Linear,
value: Linear,
out: Linear,
n_head: usize,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
Ok(Self {
query,
key,
value,
out,
n_head,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
}
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
fn qkv_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let (_, n_ctx, n_state) = q.dims3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = self.reshape_head(v)?.contiguous()?;
let mut qk = q.matmul(&k)?;
if let Some(mask) = mask {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
cross_attn,
mlp_linear1,
mlp_linear2,
mlp_ln,
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
.mlp_linear1
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
Ok((x + mlp)?)
}
}
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
Ok(sincos)
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
positional_embedding,
blocks,
ln_post,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, None, None)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
})
}
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa), Some(&self.mask))?;
}
let x = self.ln.forward(&x)?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,
pub config: Config,
}
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
Ok(Self {
encoder,
decoder,
config,
})
}
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
let enc = self.encoder.forward(mel)?;
let dec = self.decoder.forward(tokens, &enc)?;
Ok(dec)
}
}