Cargo fix and readme updates #21

Open · wants to merge 2 commits into main
5 changes: 4 additions & 1 deletion .gitignore
@@ -13,4 +13,7 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

-**/.DS_Store
+**/.DS_Store
+
+# Python virtual environment
+python/venv/
9 changes: 5 additions & 4 deletions README.md
@@ -10,13 +10,14 @@ This project is licensed under the terms of the MIT license.

The OpenAI Whisper models that have been converted to work in burn are available in the whisper-burn space on Hugging Face. You can find them at [https://huggingface.co/Gadersd/whisper-burn](https://huggingface.co/Gadersd/whisper-burn).

-If you have a custom fine-tuned model you can easily convert it to burn's format. Here is an example of converting OpenAI's tiny en model. The tinygrad dependency of the dump.py script should be installed from source not with pip.
+If you have a custom fine-tuned model you can easily convert it to burn's format. Here is an example of converting OpenAI's tiny en model:

```
# Download the tiny_en tokenizer
wget https://huggingface.co/Gadersd/whisper-burn/resolve/main/tiny_en/tokenizer.json

+cd python
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt

wget https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt
python3 dump.py tiny.en.pt tiny_en
mv tiny_en ../
1 change: 1 addition & 0 deletions python/requirements.txt
@@ -1,4 +1,5 @@
numpy
typing
torch
tinygrad
git+https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets@main
2 changes: 1 addition & 1 deletion src/beam.rs
@@ -1,4 +1,4 @@
use std::cmp::Ordering;


#[derive(Clone)]
pub struct BeamNode<T: Clone> {
8 changes: 2 additions & 6 deletions src/bin/convert/main.rs
@@ -2,11 +2,7 @@ use whisper::model::{load::*, *};

use burn::{
module::Module,
-tensor::{
-self,
-backend::{self, Backend},
-Int, Tensor,
-},
+tensor::backend::Backend,
};

use burn_tch::{TchBackend, TchDevice};
@@ -30,7 +26,7 @@ fn main() {
};

type Backend = TchBackend<f32>;
-let device = TchDevice::Cpu;
+let _device = TchDevice::Cpu;

let (whisper, whisper_config): (Whisper<Backend>, WhisperConfig) =
match load_whisper(&model_name) {
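Most of the Rust changes in this PR are the mechanical cleanups that `cargo fix` suggests: unused imports are dropped, and bindings that are assigned but never read get an underscore prefix. A minimal sketch of that naming convention (the string value is a stand-in for the PR's `TchDevice::Cpu`, not code from the repo):

```rust
fn main() {
    // An unused binding trips rustc's `unused_variables` lint:
    //     let device = "cpu"; // warning: unused variable: `device`
    // An underscore prefix keeps the binding (and any side effects of the
    // right-hand side) while telling the compiler the value is
    // intentionally unused:
    let _device = "cpu"; // stand-in for `TchDevice::Cpu`
}
```

Keeping the underscore-prefixed binding, rather than deleting the line, preserves the name for when the value is needed again, which suits a codebase mid-refactor.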
32 changes: 11 additions & 21 deletions src/bin/transcribe/main.rs
@@ -1,10 +1,11 @@
use std::collections::HashMap;
use std::iter;

use whisper::helper::*;
use whisper::model::*;
-use whisper::{token, token::Language};
+use whisper::token::Language;
use whisper::transcribe::waveform_to_text;
+use whisper::token::Gpt2Tokenizer;

+use burn::record::{DefaultRecorder, Recorder, RecorderError};

+use std::{env, fs, process};

use strum::IntoEnumIterator;

@@ -19,23 +20,19 @@ cfg_if::cfg_if! {
use burn::{
config::Config,
module::Module,
-tensor::{
-self,
-backend::{self, Backend},
-Data, Float, Int, Tensor,
-},
+tensor::backend::Backend,
};

use hound::{self, SampleFormat};

fn load_audio_waveform<B: Backend>(filename: &str) -> hound::Result<(Vec<f32>, usize)> {
-let mut reader = hound::WavReader::open(filename)?;
+let reader = hound::WavReader::open(filename)?;
let spec = reader.spec();

-let duration = reader.duration() as usize;
+let _duration = reader.duration() as usize;
let channels = spec.channels as usize;
let sample_rate = spec.sample_rate as usize;
-let bits_per_sample = spec.bits_per_sample;
+let _bits_per_sample = spec.bits_per_sample;
let sample_format = spec.sample_format;

assert_eq!(sample_rate, 16000, "The audio sample rate must be 16k.");
@@ -54,12 +51,6 @@ fn load_audio_waveform<B: Backend>(filename: &str) -> hound::Result<(Vec<f32>, usize)> {
return Ok((floats, sample_rate));
}

-use num_traits::ToPrimitive;
-use whisper::audio::prep_audio;
-use whisper::token::{Gpt2Tokenizer, SpecialToken};
-
-use burn::record::{DefaultRecorder, Recorder, RecorderError};

fn load_whisper_model_file<B: Backend>(
config: &WhisperConfig,
filename: &str,
@@ -69,7 +60,6 @@ fn main() {
.map(|record| config.init().load_record(record))
}

-use std::{env, fs, process};

fn main() {
cfg_if::cfg_if! {
@@ -142,7 +132,7 @@ fn main() {

let whisper = whisper.to_device(&device);

-let (text, tokens) = match waveform_to_text(&whisper, &bpe, lang, waveform, sample_rate) {
+let (text, _tokens) = match waveform_to_text(&whisper, &bpe, lang, waveform, sample_rate) {
Ok((text, tokens)) => (text, tokens),
Err(e) => {
eprintln!("Error during transcription: {}", e);
2 changes: 1 addition & 1 deletion src/helper.rs
@@ -1,5 +1,5 @@
use burn::tensor::{
-activation::relu, backend::Backend, BasicOps, Bool, Element, Float, Int, Numeric, Tensor,
+activation::relu, backend::Backend, BasicOps, Element, Numeric, Tensor,
TensorKind,
};

2 changes: 1 addition & 1 deletion src/model/load.rs
@@ -5,7 +5,7 @@ use burn::{
conv::{Conv1d, Conv1dConfig, Conv1dRecord},
PaddingConfig1d,
},
-tensor::{activation::relu, backend::Backend, Bool, Int, Tensor},
+tensor::{backend::Backend, Tensor},
};

use super::*;
4 changes: 2 additions & 2 deletions src/model/mod.rs
@@ -7,7 +7,7 @@ use burn::{
module::{Module, Param},
nn::{
self,
-conv::{Conv1d, Conv1dConfig, Conv1dRecord},
+conv::{Conv1d, Conv1dConfig},
PaddingConfig1d,
},
tensor::{activation::softmax, backend::Backend, module::embedding, Distribution, Int, Tensor},
@@ -129,7 +129,7 @@ pub struct TextDecoder<B: Backend> {

impl<B: Backend> TextDecoder<B> {
fn forward(&self, x: Tensor<B, 2, Int>, xa: Tensor<B, 3>) -> Tensor<B, 3> {
-let [n_batch, seq_len] = x.dims();
+let [_n_batch, seq_len] = x.dims();

assert!(
seq_len <= self.n_text_ctx,
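The same convention extends to destructuring patterns: elements that are never read keep a descriptive name behind the underscore, which documents the tensor shape better than a bare `_`. A self-contained sketch, using a hypothetical `[batch, sequence]` shape rather than the model's real dimensions:

```rust
fn main() {
    // Mirrors `let [_n_batch, seq_len] = x.dims();` in the diff above:
    // only `seq_len` is used afterwards, but `_n_batch` still records
    // what the first dimension means.
    let dims: [usize; 2] = [8, 512]; // hypothetical [batch, sequence] shape
    let [_n_batch, seq_len] = dims;

    let n_text_ctx = 1024; // illustrative decoder context limit
    assert!(seq_len <= n_text_ctx, "seq_len exceeds the decoder context");
    println!("seq_len = {seq_len}");
}
```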
4 changes: 2 additions & 2 deletions src/token.rs
@@ -1,7 +1,7 @@
use serde::ser::StdError;
use std::result;

-use tokenizers::{AddedToken, Tokenizer};
+use tokenizers::{AddedToken};

pub type Result<T> = result::Result<T, Box<(dyn StdError + Send + Sync + 'static)>>;

@@ -12,7 +12,7 @@ pub struct Gpt2Tokenizer {
impl Gpt2Tokenizer {
pub fn new() -> Result<Self> {
//let mut tokenizer = tokenizers::Tokenizer::from_pretrained("gpt2", None)?;
-let mut tokenizer = tokenizers::Tokenizer::from_file("tokenizer.json")?;
+let tokenizer = tokenizers::Tokenizer::from_file("tokenizer.json")?;
//tokenizer.add_special_tokens(&construct_special_tokens());

Ok(Self { tokenizer })
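A related lint drives the `let mut tokenizer` change above: rustc's `unused_mut` warns when a binding is declared mutable but never mutated, as here where the mutating `add_special_tokens` call is commented out. A minimal sketch, with a `Vec` standing in for the `Tokenizer` value:

```rust
fn main() {
    // Before: `let mut tokenizer = ...;`
    //   warning: variable does not need to be mutable
    // Dropping `mut` makes the binding immutable, matching how it is
    // actually used (read-only after construction):
    let tokenizer = vec!["<|startoftranscript|>", "<|en|>"]; // stand-in value
    println!("{} tokens loaded", tokenizer.len());
}
```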
23 changes: 11 additions & 12 deletions src/transcribe.rs
@@ -1,20 +1,19 @@
use crate::audio::{max_waveform_samples, prep_audio};
use crate::helper::*;

use crate::model::*;
use crate::token::{self, *};
use crate::beam;

use num_traits::ToPrimitive;


use std::iter;

use burn::{
config::Config,
module::Module,
tensor::{
-self,
-backend::{self, Backend},
-Data, Float, Int, Tensor,
+backend::{Backend},
+Data, Tensor,
ElementConversion,
activation::log_softmax,
},
@@ -50,7 +49,7 @@ pub fn waveform_to_text<B: Backend>(
prev_normal_tokens.reverse();
//println!("Prev tokens: {:?} {}", prev_normal_tokens, bpe.decode(&prev_normal_tokens[..], false)?);

-let (new_text, new_tokens) =
+let (_new_text, new_tokens) =
mels_to_text(whisper, bpe, lang, mel, &prev_normal_tokens[..], padding)?;

if let Some((prev_index, curr_index)) =
@@ -156,9 +155,9 @@ fn mels_to_text<B: Backend>(
let device = mels.device();

let n_ctx_max_encoder = whisper.encoder_ctx_size();
-let n_ctx_max_decoder = whisper.decoder_ctx_size();
+let _n_ctx_max_decoder = whisper.decoder_ctx_size();

-let [n_channel, n_mel, n_ctx] = mels.dims();
+let [_n_channel, n_mel, n_ctx] = mels.dims();
if n_ctx + padding > n_ctx_max_encoder {
println!(
"Audio has length of {} which exceeds maximum length {}. It will be clipped.",
@@ -180,7 +179,7 @@ fn mels_to_text<B: Backend>(
let transcription_token = bpe.special_token(SpecialToken::Transcribe).unwrap();
let start_of_prev_token = bpe.special_token(SpecialToken::StartofPrev).unwrap();
let lang_token = bpe.special_token(SpecialToken::Language(lang)).unwrap();
-let first_timestamp_token = bpe.special_token(SpecialToken::Timestamp(0.0)).unwrap();
+let _first_timestamp_token = bpe.special_token(SpecialToken::Timestamp(0.0)).unwrap();
let end_token = bpe.special_token(SpecialToken::EndofText).unwrap();
let notimestamp = bpe.special_token(SpecialToken::NoTimeStamps).unwrap();

@@ -192,7 +191,7 @@ fn mels_to_text<B: Backend>(
.chain(iter::once(bpe.special_token(SpecialToken::Timestamp(0.0)).unwrap()))
.collect();*/

-let mut initial_tokens = if prev_nonspecial_tokens.len() > 0 {
+let _initial_tokens = if prev_nonspecial_tokens.len() > 0 {
iter::once(start_of_prev_token).chain(prev_nonspecial_tokens.iter().cloned()).collect()
} else {
Vec::new()
@@ -241,7 +240,7 @@ fn mels_to_text<B: Backend>(
};

let vocab_size = bpe.vocab_size();
-let mut special_tokens_maskout: Vec<f32> = (0..vocab_size).into_iter().map(|token| if bpe.is_special(token) {neg_infty} else {0.0}).collect();
+let special_tokens_maskout: Vec<f32> = (0..vocab_size).into_iter().map(|token| if bpe.is_special(token) {neg_infty} else {0.0}).collect();
//special_tokens_maskout[end_token] = 1.0;

let special_tokens_maskout = Tensor::from_data(Data::new(
@@ -275,7 +274,7 @@ fn mels_to_text<B: Backend>(
};
let log_probs = log_softmax(logits, 2);

-let [n_batch, n_token, n_dict] = log_probs.dims();
+let [_n_batch, _n_token, _n_dict] = log_probs.dims();
let beam_log_probs = beams.iter().enumerate().map(|(i, beam)| {
let batch = i;
let token_index = beam.seq.len() - 1;