Skip to content

Commit

Permalink
Add: ttsまでできるように
Browse files Browse the repository at this point in the history
  • Loading branch information
sevenc-nanashi committed Mar 16, 2024
1 parent e7fdda5 commit 4022167
Show file tree
Hide file tree
Showing 12 changed files with 567 additions and 64 deletions.
6 changes: 3 additions & 3 deletions crates/voicevox_core/src/infer/runtimes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Select the onnxruntime backend at compile time: the native
// binding everywhere except wasm, the onnxruntime-web shim on wasm.
#[cfg(not(target_family = "wasm"))]
mod onnxruntime;
#[cfg(target_family = "wasm")]
#[path = "runtimes/onnxruntime_wasm.rs"]
mod onnxruntime;

Expand Down
169 changes: 153 additions & 16 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime_wasm.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::any::Any;
use std::mem::ManuallyDrop;
use std::sync::Arc;
#![allow(unsafe_code)]
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::sync::Mutex;
use std::{fmt::Debug, vec};

use anyhow::anyhow;
use duplicate::duplicate_item;
use ndarray::{Array, Dimension};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use tracing::info;

use crate::devices::SupportedDevices;

Expand All @@ -15,6 +18,34 @@ use super::super::{
OutputScalarKind, OutputTensor, ParamInfo, PushInputTensor,
};

// Completed async results keyed by nonce: `js_callback` inserts,
// the polling loops in session creation/run remove.
static RESULTS: Lazy<Mutex<HashMap<String, String>>> = Lazy::new(|| Mutex::new(HashMap::new()));

/// JSON payload delivered for a successful `onnxruntime_inference_session_new` call.
#[derive(Debug, Deserialize)]
struct SessionNewResult {
    // Opaque session identifier minted by the JS side; field name is wire format.
    handle: String,
}
/// Result envelope produced by the JS glue. Adjacently tagged, so it
/// deserializes from `{"type": "ok", "payload": ...}` or
/// `{"type": "err", "payload": "<message>"}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "payload", rename_all = "camelCase")]
enum JsResult<T> {
    Ok(T),
    Err(String),
}

// Functions supplied by the JavaScript glue (Emscripten imports).
// Each `onnxruntime_*` call returns a pointer to a NUL-terminated
// nonce string; the operation's actual result arrives later through
// `callback(nonce, json_result)` and is matched up via the nonce.
extern "C" {
    // Start creating an onnxruntime-web session from the serialized model bytes.
    fn onnxruntime_inference_session_new(
        model: *const u8,
        model_len: usize,
        use_gpu: bool,
        callback: extern "C" fn(*const u8, *const u8) -> (),
    ) -> *const u8;
    // Start an inference run; `inputs` is a NUL-terminated JSON string.
    fn onnxruntime_inference_session_run(
        handle: *const u8,
        inputs: *const u8,
        callback: extern "C" fn(*const u8, *const u8) -> (),
    ) -> *const u8;
    // Emscripten Asyncify sleep: yields to the JS event loop so the
    // async callbacks above can actually fire while we poll.
    fn emscripten_sleep(millis: i32);
}

/// Marker type for the onnxruntime-web backend.
/// Uninhabited on purpose: never instantiated, used only as an
/// `InferenceRuntime` implementor at the type level.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub(crate) enum Onnxruntime {}

Expand All @@ -25,7 +56,7 @@ impl InferenceRuntime for Onnxruntime {
fn supported_devices() -> crate::Result<SupportedDevices> {
Ok(SupportedDevices {
cpu: true,
cuda: false,
cuda: true,
dml: false,
})
}
Expand All @@ -38,39 +69,145 @@ impl InferenceRuntime for Onnxruntime {
Vec<ParamInfo<InputScalarKind>>,
Vec<ParamInfo<OutputScalarKind>>,
)> {
todo!()
unsafe {
info!("creating new session");
let model = model()?;
let model_len = model.len();
let cpu_num_threads = options.cpu_num_threads as usize;
let use_gpu = options.use_gpu;
let nonce =
onnxruntime_inference_session_new(model.as_ptr(), model_len, use_gpu, js_callback);

let nonce = CStr::from_ptr(nonce as *const i8)
.to_str()
.map_err(|err| anyhow!(err))?
.to_string();
info!("nonce: {}", nonce);

let result = loop {
let result = RESULTS.lock().expect("mutex poisoned").remove(&nonce);
if let Some(result) = result {
break result;
}
emscripten_sleep(10);
};

let result: JsResult<SessionNewResult> = serde_json::from_str(&result)?;
let result = match result {
JsResult::Ok(result) => result,
JsResult::Err(err) => return Err(anyhow!(err)),
};

let handle = result.handle;
let session = OnnxruntimeSession { handle };
Ok((session, vec![], vec![]))
}
}

/// Runs one inference on the JS-side session identified by
/// `ctx.session.handle`, blocking (via Asyncify sleeps) until the
/// async JS result arrives, and converts the returned tensors into
/// `OutputTensor`s.
fn run(ctx: OnnxruntimeRunContext<'_>) -> anyhow::Result<Vec<OutputTensor>> {
    unsafe {
        // Hand the session handle and JSON-encoded inputs across the FFI boundary.
        let handle_cstr = CString::new(ctx.session.handle.clone())?;
        let inputs = serde_json::to_string(&ctx.inputs)?;
        let inputs_cstr = CString::new(inputs)?;
        // NOTE(review): `into_raw` transfers ownership and the memory is
        // never reclaimed on the Rust side — confirm the JS glue frees
        // these buffers, otherwise every run leaks two allocations.
        let nonce = onnxruntime_inference_session_run(
            handle_cstr.into_raw() as _,
            inputs_cstr.into_raw() as _,
            js_callback,
        );
        let nonce = CStr::from_ptr(nonce as *const i8)
            .to_str()
            .map_err(|err| anyhow!(err))?
            .to_string();

        // Poll until js_callback has deposited the result for our nonce,
        // yielding to the JS event loop between checks.
        let result = loop {
            let result = RESULTS.lock().expect("mutex poisoned").remove(&nonce);
            if let Some(result) = result {
                break result;
            }
            emscripten_sleep(10);
        };
        let result: JsResult<Vec<Tensor>> = serde_json::from_str(&result)?;
        let result = match result {
            JsResult::Ok(result) => result,
            JsResult::Err(err) => return Err(anyhow!(err)),
        };

        Ok(result
            .into_iter()
            .map(|tensor| {
                let shape = tensor.shape;
                match tensor.data {
                    // No model output currently uses int64 tensors.
                    TensorData::Int64(_) => unimplemented!(),
                    TensorData::Float32(data) => {
                        OutputTensor::Float32(Array::from_shape_vec(shape, data).unwrap())
                    }
                }
            })
            .collect())
    }
}
}

pub(crate) struct OnnxruntimeSession {}
/// C-ABI callback handed to the JS glue; invoked once an async
/// onnxruntime-web operation finishes.
///
/// Both pointers must reference NUL-terminated UTF-8 strings: `nonce`
/// identifies the pending request, `result` carries its JSON-encoded
/// outcome. The pair is stored in the global `RESULTS` map, where the
/// polling loops pick it up.
extern "C" fn js_callback(nonce: *const u8, result: *const u8) {
    // SAFETY: the JS glue passes valid, NUL-terminated C strings that
    // stay alive for the duration of this call.
    let (nonce_cstr, result_cstr) = unsafe {
        (
            CStr::from_ptr(nonce as *const i8),
            CStr::from_ptr(result as *const i8),
        )
    };
    let nonce = nonce_cstr.to_str().expect("invalid handle").to_owned();
    let result = result_cstr.to_str().expect("invalid result").to_owned();
    info!("callback called with nonce: {}", nonce);
    RESULTS.lock().expect("mutex poisoned").insert(nonce, result);
}

/// An inference session that lives on the JS side; `handle` is the
/// opaque key used to refer to it across the FFI boundary.
pub(crate) struct OnnxruntimeSession {
    handle: String,
}

impl Drop for OnnxruntimeSession {
fn drop(&mut self) {
todo!()
}
/// Tensor payload exchanged with JS as JSON. Adjacently tagged:
/// `{"kind": "int64" | "float32", "array": [...]}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "kind", content = "array", rename_all = "camelCase")]
pub(crate) enum TensorData {
    Int64(Vec<i64>),
    Float32(Vec<f32>),
}
/// A flattened tensor plus its shape, as serialized over the JS bridge.
/// Presumably `data` length equals the product of `shape` — the JS side
/// is trusted on this; verify against the glue code.
#[derive(Serialize, Deserialize)]
pub(crate) struct Tensor {
    data: TensorData,
    shape: Vec<usize>,
}

/// Per-run state: the target session plus the input tensors
/// accumulated through the `PushInputTensor` impl below.
pub(crate) struct OnnxruntimeRunContext<'sess> {
    session: &'sess mut OnnxruntimeSession,
    inputs: Vec<Tensor>,
}

impl<'sess> From<&'sess mut OnnxruntimeSession> for OnnxruntimeRunContext<'sess> {
    /// Wraps a session in a fresh run context with no inputs queued yet.
    fn from(sess: &'sess mut OnnxruntimeSession) -> Self {
        Self {
            session: sess,
            inputs: vec![],
        }
    }
}

impl PushInputTensor for OnnxruntimeRunContext<'_> {
    // Generates `push_int64` and `push_float32`, which flatten the
    // ndarray into a `Tensor` (shape + raw vec) and queue it as an input.
    #[duplicate_item(
        method           T       kind_item;
        [ push_int64 ]   [ i64 ] [ Int64 ];
        [ push_float32 ] [ f32 ] [ Float32 ];
    )]
    fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) {
        let shape = tensor.shape().to_vec();
        let tensor_vec = tensor.into_raw_vec();
        self.inputs.push(Tensor {
            data: TensorData::kind_item(tensor_vec),
            shape,
        });
    }
}
10 changes: 10 additions & 0 deletions crates/voicevox_core/src/infer/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ impl<R: InferenceRuntime, D: InferenceDomain> SessionSet<R, D> {
sessions.remove(&k.into_usize()).expect("should exist")
})));

#[cfg(target_family = "wasm")]
fn check_param_infos<D: PartialEq + Display>(
    _expected: &[ParamInfo<D>],
    _actual: &[ParamInfo<D>],
) -> anyhow::Result<()> {
    // onnxruntime-web cannot expose model parameter info, so skip the check.
    // ref: https://github.com/microsoft/onnxruntime/discussions/17682
    Ok(())
}
#[cfg(not(target_family = "wasm"))]
fn check_param_infos<D: PartialEq + Display>(
expected: &[ParamInfo<D>],
actual: &[ParamInfo<D>],
Expand Down
Loading

0 comments on commit 4022167

Please sign in to comment.