
Commit

feat!: Use RunAsync
qryxip committed Dec 8, 2024
1 parent 06ec811 commit 438f14f
Showing 11 changed files with 225 additions and 144 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -99,7 +99,7 @@ zip = "0.6.3"

[workspace.dependencies.voicevox-ort]
git = "https://github.com/VOICEVOX/ort.git"
rev = "17f741301db0bb08da0eafe8a338e5efd8a4b5df"
rev = "09a9fe1619c1561efafc02f68f0bda4aad879771"

[workspace.dependencies.open_jtalk]
git = "https://github.com/VOICEVOX/open_jtalk-rs.git"
58 changes: 49 additions & 9 deletions crates/voicevox_core/src/infer.rs
@@ -3,7 +3,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod session_set;

use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, future::Future, ops::Index, sync::Arc};

use derive_new::new;
use duplicate::duplicate_item;
@@ -12,14 +12,39 @@ use ndarray::{Array, ArrayD, Dimension, ShapeError};
use thiserror::Error;

use crate::{
asyncs::{Async, BlockingThreadPool, SingleTasked},
devices::{DeviceSpec, GpuSpec},
StyleType, SupportedDevices,
};

pub(crate) trait AsyncExt: Async {
async fn run_session<R: InferenceRuntime>(
ctx: R::RunContext,
) -> anyhow::Result<Vec<OutputTensor>>;
}

impl AsyncExt for SingleTasked {
async fn run_session<R: InferenceRuntime>(
ctx: R::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
R::run(ctx)
}
}

impl AsyncExt for BlockingThreadPool {
async fn run_session<R: InferenceRuntime>(
ctx: R::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
R::run_async(ctx).await
}
}

pub(crate) trait InferenceRuntime: 'static {
// TODO: "session"とは何なのかを定め、ドキュメントを書く。`InferenceSessionSet`も同様。
type Session: Sized + Send + 'static;
type RunContext<'a>: From<&'a mut Self::Session> + PushInputTensor;
type Session;

// Ideally this would be `From<'_ Self::Session>`, but rust-lang/rust#100013 is in the way
type RunContext: From<Arc<Self::Session>> + PushInputTensor;

/// The name.
const DISPLAY_NAME: &'static str;
Expand All @@ -45,7 +70,11 @@ pub(crate) trait InferenceRuntime: 'static {
Vec<ParamInfo<OutputScalarKind>>,
)>;

fn run(ctx: Self::RunContext<'_>) -> anyhow::Result<Vec<OutputTensor>>;
fn run(ctx: Self::RunContext) -> anyhow::Result<Vec<OutputTensor>>;

fn run_async(
ctx: Self::RunContext,
) -> impl Future<Output = anyhow::Result<Vec<OutputTensor>>> + Send;
}

/// Indicates a set of inference operations that should be handled together.
@@ -101,15 +130,16 @@ pub(crate) trait InferenceInputSignature: Send + 'static {
const PARAM_INFOS: &'static [ParamInfo<InputScalarKind>];
fn make_run_context<R: InferenceRuntime>(
self,
sess: &mut R::Session,
) -> anyhow::Result<R::RunContext<'_>>;
sess: Arc<R::Session>,
) -> anyhow::Result<R::RunContext>;
}

pub(crate) trait InputScalar: Sized {
const KIND: InputScalarKind;

// TODO: it might be possible to take an `ArrayView` instead of an `Array`
fn push_tensor_to_ctx(
name: &'static str,
tensor: Array<Self, impl Dimension + 'static>,
visitor: &mut impl PushInputTensor,
) -> anyhow::Result<()>;
@@ -124,10 +154,11 @@ impl InputScalar for T {
const KIND: InputScalarKind = KIND_VAL;

fn push_tensor_to_ctx(
name: &'static str,
tensor: Array<Self, impl Dimension + 'static>,
ctx: &mut impl PushInputTensor,
) -> anyhow::Result<()> {
ctx.push(tensor)
ctx.push(name, tensor)
}
}

@@ -141,8 +172,17 @@ pub(crate) enum InputScalarKind {
}

pub(crate) trait PushInputTensor {
fn push_int64(&mut self, tensor: Array<i64, impl Dimension + 'static>) -> anyhow::Result<()>;
fn push_float32(&mut self, tensor: Array<f32, impl Dimension + 'static>) -> anyhow::Result<()>;
fn push_int64(
&mut self,
name: &'static str,
tensor: Array<i64, impl Dimension + 'static>,
) -> anyhow::Result<()>;

fn push_float32(
&mut self,
name: &'static str,
tensor: Array<f32, impl Dimension + 'static>,
) -> anyhow::Result<()>;
}

/// The output signature of an inference operation.
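
As a reading aid, here is a minimal, self-contained sketch of the dispatch pattern the new trait enables: callers stay generic over the execution mode, one impl finishes synchronously inside its async fn, and the other defers to the runtime's own async path. Only the shape of `AsyncExt` is taken from the diff; `Runtime`, `Blocking`, `NonBlocking`, and `MockRuntime` are hypothetical stand-ins for the crate's `InferenceRuntime`, `SingleTasked`, and `BlockingThreadPool`, `anyhow::Result` is replaced with a plain `Result`, the `Send` bound on `run_async` is dropped for brevity, and the `futures` crate is assumed only to drive the demo (Rust 1.75+ for async fns in traits).

trait Runtime: 'static {
    type RunContext;
    fn run(ctx: Self::RunContext) -> Result<Vec<f32>, String>;
    async fn run_async(ctx: Self::RunContext) -> Result<Vec<f32>, String>;
}

trait AsyncExt {
    // One entry point; each mode decides whether to finish immediately or to await.
    async fn run_session<R: Runtime>(ctx: R::RunContext) -> Result<Vec<f32>, String>;
}

struct Blocking; // cf. `SingleTasked`
struct NonBlocking; // cf. `BlockingThreadPool`

impl AsyncExt for Blocking {
    async fn run_session<R: Runtime>(ctx: R::RunContext) -> Result<Vec<f32>, String> {
        R::run(ctx) // ready on the first poll; nothing is awaited
    }
}

impl AsyncExt for NonBlocking {
    async fn run_session<R: Runtime>(ctx: R::RunContext) -> Result<Vec<f32>, String> {
        R::run_async(ctx).await // defers to the runtime's async implementation
    }
}

struct MockRuntime;

impl Runtime for MockRuntime {
    type RunContext = Vec<f32>;

    fn run(ctx: Vec<f32>) -> Result<Vec<f32>, String> {
        Ok(ctx.into_iter().map(|x| x * 2.0).collect())
    }

    async fn run_async(ctx: Vec<f32>) -> Result<Vec<f32>, String> {
        // A real runtime would offload the blocking work to a thread pool here.
        Self::run(ctx)
    }
}

fn main() -> Result<(), String> {
    let a = futures::executor::block_on(Blocking::run_session::<MockRuntime>(vec![1.0, 2.0]))?;
    let b = futures::executor::block_on(NonBlocking::run_session::<MockRuntime>(vec![1.0, 2.0]))?;
    assert_eq!(a, b); // both modes produce the same result through one generic entry point
    Ok(())
}
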
86 changes: 50 additions & 36 deletions crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
@@ -11,7 +11,7 @@
// }
// ```

use std::{fmt::Debug, vec};
use std::{fmt::Debug, sync::Arc, vec};

use anyhow::{anyhow, bail, ensure};
use duplicate::duplicate_item;
@@ -32,8 +32,8 @@ use super::super::{
};

impl InferenceRuntime for self::blocking::Onnxruntime {
type Session = ort::Session;
type RunContext<'a> = OnnxruntimeRunContext<'a>;
type Session = async_lock::Mutex<ort::Session>; // `ort` is not used on WASM, so it should be fine to use async-lock here
type RunContext = OnnxruntimeRunContext;

const DISPLAY_NAME: &'static str = if cfg!(feature = "load-onnxruntime") {
"現在ロードされているONNX Runtime"
@@ -179,76 +179,90 @@ impl InferenceRuntime for self::blocking::Onnxruntime {
})
.collect::<anyhow::Result<_>>()?;

Ok((sess, input_param_infos, output_param_infos))
Ok((sess.into(), input_param_infos, output_param_infos))
}

fn run(
OnnxruntimeRunContext { sess, inputs }: OnnxruntimeRunContext<'_>,
OnnxruntimeRunContext { sess, inputs }: Self::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
let outputs = sess.run(&*inputs)?;

(0..outputs.len())
.map(|i| {
let output = &outputs[i];

let ValueType::Tensor { ty, .. } = output.dtype()? else {
bail!(
"unexpected output. currently `ONNX_TYPE_TENSOR` and \
`ONNX_TYPE_SPARSETENSOR` is supported",
);
};
extract_outputs(&sess.lock_blocking().run(inputs)?)
}

match ty {
TensorElementType::Float32 => {
let output = output.try_extract_tensor::<f32>()?;
Ok(OutputTensor::Float32(output.into_owned()))
}
_ => bail!("unexpected output tensor element data type"),
}
})
.collect()
async fn run_async(
OnnxruntimeRunContext { sess, inputs }: Self::RunContext,
) -> anyhow::Result<Vec<OutputTensor>> {
extract_outputs(&sess.lock().await.run_async(inputs)?.await?)
}
}

pub(crate) struct OnnxruntimeRunContext<'sess> {
sess: &'sess ort::Session,
inputs: Vec<ort::SessionInputValue<'static>>,
pub(crate) struct OnnxruntimeRunContext {
sess: Arc<async_lock::Mutex<ort::Session>>,
inputs: Vec<(&'static str, ort::SessionInputValue<'static>)>,
}

impl OnnxruntimeRunContext<'_> {
impl OnnxruntimeRunContext {
fn push_input(
&mut self,
name: &'static str,
input: Array<
impl PrimitiveTensorElementType + Debug + Clone + 'static,
impl Dimension + 'static,
>,
) -> anyhow::Result<()> {
let input = ort::Value::from_array(input)?.into();
self.inputs.push(input);
self.inputs.push((name, input));
Ok(())
}
}

impl<'sess> From<&'sess mut ort::Session> for OnnxruntimeRunContext<'sess> {
fn from(sess: &'sess mut ort::Session) -> Self {
impl From<Arc<async_lock::Mutex<ort::Session>>> for OnnxruntimeRunContext {
fn from(sess: Arc<async_lock::Mutex<ort::Session>>) -> Self {
Self {
sess,
inputs: vec![],
}
}
}

impl PushInputTensor for OnnxruntimeRunContext<'_> {
impl PushInputTensor for OnnxruntimeRunContext {
#[duplicate_item(
method T;
[ push_int64 ] [ i64 ];
[ push_float32 ] [ f32 ];
)]
fn method(&mut self, tensor: Array<T, impl Dimension + 'static>) -> anyhow::Result<()> {
self.push_input(tensor)
fn method(
&mut self,
name: &'static str,
tensor: Array<T, impl Dimension + 'static>,
) -> anyhow::Result<()> {
self.push_input(name, tensor)
}
}

// FIXME: use ouroboros to reduce copies
fn extract_outputs(outputs: &ort::SessionOutputs<'_, '_>) -> anyhow::Result<Vec<OutputTensor>> {
(0..outputs.len())
.map(|i| {
let output = &outputs[i];

let ValueType::Tensor { ty, .. } = output.dtype()? else {
bail!(
"unexpected output. currently `ONNX_TYPE_TENSOR` and `ONNX_TYPE_SPARSETENSOR`
is supported",
);
};

match ty {
TensorElementType::Float32 => {
let output = output.try_extract_tensor::<f32>()?;
Ok(OutputTensor::Float32(output.into_owned()))
}
_ => bail!("unexpected output tensor element data type"),
}
})
.collect()
}

pub(crate) mod blocking {
use ort::EnvHandle;
use ref_cast::{ref_cast_custom, RefCastCustom};
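
The switch above to `async_lock::Mutex` gives a single shared session two front doors, `lock_blocking()` on the synchronous `run` path and `lock().await` on the `run_async` path, while the run context now owns the session through an `Arc` and carries named `(name, tensor)` inputs. Below is that pattern in isolation; `FakeSession` and its `run` method are hypothetical stand-ins for `ort::Session`, and the sketch assumes only the `async-lock` and `futures` crates.

use std::sync::Arc;

// Hypothetical stand-in for `ort::Session`.
struct FakeSession;

impl FakeSession {
    fn run(&mut self, inputs: &[(&'static str, Vec<f32>)]) -> Vec<f32> {
        inputs.iter().flat_map(|(_name, values)| values.iter().copied()).collect()
    }
}

struct RunContext {
    sess: Arc<async_lock::Mutex<FakeSession>>,
    inputs: Vec<(&'static str, Vec<f32>)>,
}

impl From<Arc<async_lock::Mutex<FakeSession>>> for RunContext {
    fn from(sess: Arc<async_lock::Mutex<FakeSession>>) -> Self {
        Self { sess, inputs: vec![] }
    }
}

fn run(ctx: RunContext) -> Vec<f32> {
    // Synchronous path: park the current thread until the lock is acquired.
    ctx.sess.lock_blocking().run(&ctx.inputs)
}

async fn run_async(ctx: RunContext) -> Vec<f32> {
    // Async path: yield to the executor instead of blocking while waiting.
    ctx.sess.lock().await.run(&ctx.inputs)
}

fn main() {
    let sess = Arc::new(async_lock::Mutex::new(FakeSession));

    let mut ctx = RunContext::from(sess.clone());
    ctx.inputs.push(("phoneme", vec![1.0, 2.0]));
    assert_eq!(run(ctx), vec![1.0, 2.0]);

    let mut ctx = RunContext::from(sess);
    ctx.inputs.push(("phoneme", vec![3.0]));
    assert_eq!(futures::executor::block_on(run_async(ctx)), vec![3.0]);
}
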
19 changes: 11 additions & 8 deletions crates/voicevox_core/src/infer/session_set.rs
@@ -12,7 +12,7 @@ use super::{
};

pub(crate) struct InferenceSessionSet<R: InferenceRuntime, D: InferenceDomain>(
EnumMap<D::Operation, Arc<std::sync::Mutex<R::Session>>>,
EnumMap<D::Operation, Arc<R::Session>>,
);

impl<R: InferenceRuntime, D: InferenceDomain> InferenceSessionSet<R, D> {
@@ -33,7 +33,7 @@ impl<R: InferenceRuntime, D: InferenceDomain> InferenceSessionSet<R, D> {
check_param_infos(expected_input_param_infos, &actual_input_param_infos)?;
check_param_infos(expected_output_param_infos, &actual_output_param_infos)?;

Ok((op.into_usize(), std::sync::Mutex::new(sess).into()))
Ok((op.into_usize(), sess.into()))
})
.collect::<anyhow::Result<HashMap<_, _>>>()?;

@@ -84,18 +84,21 @@ impl<R: InferenceRuntime, D: InferenceDomain> InferenceSessionSet<R, D> {
}

pub(crate) struct InferenceSessionCell<R: InferenceRuntime, I> {
inner: Arc<std::sync::Mutex<R::Session>>,
inner: Arc<R::Session>,
marker: PhantomData<fn(I)>,
}

impl<R: InferenceRuntime, I: InferenceInputSignature> InferenceSessionCell<R, I> {
pub(crate) fn run(
pub(crate) async fn run<A: super::AsyncExt>(
self,
input: I,
) -> crate::Result<<I::Signature as InferenceSignature>::Output> {
let inner = &mut self.inner.lock().unwrap();
(|| R::run(input.make_run_context::<R>(inner)?)?.try_into())()
.map_err(ErrorRepr::RunModel)
.map_err(Into::into)
async {
let ctx = input.make_run_context::<R>(self.inner.clone())?;
A::run_session::<R>(ctx).await?.try_into()
}
.await
.map_err(ErrorRepr::RunModel)
.map_err(Into::into)
}
}
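
`InferenceSessionCell::run` (and everything layered on top of it) is now an `async fn` even when the caller picks the single-tasked mode, whose `AsyncExt` impl never awaits anything pending. The sketch below is not the crate's code, but it illustrates the property a blocking caller can presumably rely on: an async fn whose body never suspends is ready on its very first poll, so it can be driven to completion without a real executor. `Waker::noop` needs Rust 1.85+; a hand-rolled no-op `RawWaker` works the same way on older toolchains.

use std::{
    future::Future,
    pin::pin,
    task::{Context, Poll, Waker},
};

// Stands in for a `run::<SingleTasked>`-style call: async on the outside,
// but nothing in the body ever returns `Poll::Pending`.
async fn run_without_suspending(input: Vec<i64>) -> Vec<i64> {
    input.into_iter().map(|x| x + 1).collect()
}

fn main() {
    let mut fut = pin!(run_without_suspending(vec![1, 2, 3]));
    let mut cx = Context::from_waker(Waker::noop());
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(out) => assert_eq!(out, vec![2, 3, 4]),
        Poll::Pending => unreachable!("the body never awaits a pending future"),
    }
}
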
6 changes: 4 additions & 2 deletions crates/voicevox_core/src/status.rs
@@ -9,6 +9,7 @@ use itertools::iproduct;
use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
self,
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain},
session_set::{InferenceSessionCell, InferenceSessionSet},
InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
@@ -104,17 +105,18 @@ impl<R: InferenceRuntime> Status<R> {
/// # Panics
///
/// Panics if `self` does not contain `model_id`.
pub(crate) fn run_session<I>(
pub(crate) async fn run_session<A, I>(
&self,
model_id: VoiceModelId,
input: I,
) -> Result<<I::Signature as InferenceSignature>::Output>
where
A: infer::AsyncExt,
I: InferenceInputSignature,
<I::Signature as InferenceSignature>::Domain: InferenceDomainExt,
{
let sess = self.loaded_models.lock().unwrap().get(model_id);
sess.run(input)
sess.run::<A>(input).await
}
}

