Skip to content

Commit

Permalink
change: VoiceModelVoiceModelFile
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Sep 16, 2024
1 parent b8118b9 commit f32872e
Show file tree
Hide file tree
Showing 48 changed files with 653 additions and 303 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ anstyle-query = "1.0.0"
anyhow = "1.0.65"
assert_cmd = "2.0.8"
async-fs = "2.1.2"
async-lock = "3.4.0"
async_zip = "=0.0.16"
bindgen = "0.69.4"
binstall-tar = "0.4.39"
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ link-onnxruntime = []

[dependencies]
anyhow.workspace = true
async-fs.workspace = true
async-fs.workspace = true # 今これを使っている箇所はどこにも無いが、`UserDict`にはこれを使った方がよいはず
async-lock.workspace = true
async_zip = { workspace = true, features = ["deflate"] }
blocking.workspace = true
camino.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/__internal/doctest_fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub async fn synthesizer_with_sample_voice_model(
},
)?;

let model = &crate::nonblocking::VoiceModel::from_path(voice_model_path).await?;
let model = &crate::nonblocking::VoiceModelFile::open(voice_model_path).await?;
syntesizer.load_voice_model(model).await?;

Ok(syntesizer)
Expand Down
147 changes: 121 additions & 26 deletions crates/voicevox_core/src/asyncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,24 @@
use std::{
io::{self, Read as _, Seek as _, SeekFrom},
ops::DerefMut,
path::Path,
pin::Pin,
task::{self, Poll},
};

use blocking::Unblock;
use futures_io::{AsyncRead, AsyncSeek};
use futures_util::ready;

pub(crate) trait Async: 'static {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin>;
type Mutex<T: Send + Sync + Unpin>: Mutex<T>;
type RoFile: AsyncRead + AsyncSeek + Send + Sync + Unpin;
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile>;
}

pub(crate) trait Mutex<T>: From<T> + Send + Sync + Unpin {
async fn lock(&self) -> impl DerefMut<Target = T>;
}

/// エグゼキュータが非同期タスクの並行実行をしないことを仮定する、[`Async`]の実装。
Expand All @@ -39,30 +48,47 @@ pub(crate) trait Async: 'static {
pub(crate) enum SingleTasked {}

impl Async for SingleTasked {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin> {
return std::fs::File::open(path).map(BlockingFile);

struct BlockingFile(std::fs::File);

impl AsyncRead for BlockingFile {
fn poll_read(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.read(buf))
}
}
type Mutex<T: Send + Sync + Unpin> = StdMutex<T>;
type RoFile = StdFile;

impl AsyncSeek for BlockingFile {
fn poll_seek(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
Poll::Ready(self.0.seek(pos))
}
}
async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
std::fs::File::open(path).map(StdFile)
}
}

pub(crate) struct StdMutex<T>(std::sync::Mutex<T>);

impl<T> From<T> for StdMutex<T> {
fn from(inner: T) -> Self {
Self(inner.into())
}
}

impl<T: Send + Sync + Unpin> Mutex<T> for StdMutex<T> {
async fn lock(&self) -> impl DerefMut<Target = T> {
self.0.lock().unwrap_or_else(|e| panic!("{e}"))
}
}

pub(crate) struct StdFile(std::fs::File);

impl AsyncRead for StdFile {
fn poll_read(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.read(buf))
}
}

impl AsyncSeek for StdFile {
fn poll_seek(
mut self: Pin<&mut Self>,
_: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
Poll::Ready(self.0.seek(pos))
}
}

Expand All @@ -74,7 +100,76 @@ impl Async for SingleTasked {
pub(crate) enum BlockingThreadPool {}

impl Async for BlockingThreadPool {
async fn open_file(path: impl AsRef<Path>) -> io::Result<impl AsyncRead + AsyncSeek + Unpin> {
async_fs::File::open(path).await
type Mutex<T: Send + Sync + Unpin> = async_lock::Mutex<T>;
type RoFile = AsyncRoFile;

async fn open_file_ro(path: impl AsRef<Path>) -> io::Result<Self::RoFile> {
AsyncRoFile::open(path).await
}
}

impl<T: Send + Sync + Unpin> Mutex<T> for async_lock::Mutex<T> {
async fn lock(&self) -> impl DerefMut<Target = T> {
self.lock().await
}
}

// TODO: `async_fs::File::into_std_file`みたいなのがあればこんなの↓は作らなくていいはず。PR出す?
pub(crate) struct AsyncRoFile {
// `poll_read`と`poll_seek`しかしない
unblock: Unblock<std::fs::File>,

// async-fsの実装がやっているように「正しい」シーク位置を保持する。ただしファイルはパイプではな
// いことがわかっているため smol-rs/async-fs#4 は考えない
real_seek_pos: Option<u64>,
}

impl AsyncRoFile {
async fn open(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref().to_owned();
let unblock = Unblock::new(blocking::unblock(|| std::fs::File::open(path)).await?);
Ok(Self {
unblock,
real_seek_pos: None,
})
}

pub(crate) async fn close(self) {
let file = self.unblock.into_inner().await;
blocking::unblock(|| drop(file)).await;
}
}

impl AsyncRead for AsyncRoFile {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if self.real_seek_pos.is_none() {
self.real_seek_pos = Some(ready!(
Pin::new(&mut self.unblock).poll_seek(cx, SeekFrom::Current(0))
)?);
}
let n = ready!(Pin::new(&mut self.unblock).poll_read(cx, buf))?;
*self.real_seek_pos.as_mut().expect("should be present") += n as u64;
Poll::Ready(Ok(n))
}
}

impl AsyncSeek for AsyncRoFile {
fn poll_seek(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
// async-fsの実装がやっているような"reposition"を行う。
// https://github.com/smol-rs/async-fs/issues/2#issuecomment-675595170
if let Some(real_seek_pos) = self.real_seek_pos {
ready!(Pin::new(&mut self.unblock).poll_seek(cx, SeekFrom::Start(real_seek_pos)))?;
}
self.real_seek_pos = None;

Pin::new(&mut self.unblock).poll_seek(cx, pos)
}
}
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
pub use crate::{
engine::open_jtalk::blocking::OpenJtalk, infer::runtimes::onnxruntime::blocking::Onnxruntime,
synthesizer::blocking::Synthesizer, user_dict::dict::blocking::UserDict,
voice_model::blocking::VoiceModel,
voice_model::blocking::VoiceModelFile,
};

pub mod onnxruntime {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/engine/open_jtalk.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// TODO: `VoiceModel`のように、次のような設計にする。
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/infer/runtimes/onnxruntime.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// TODO: `VoiceModel`のように、次のような設計にする。
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
Expand Down
2 changes: 1 addition & 1 deletion crates/voicevox_core/src/nonblocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
pub use crate::{
engine::open_jtalk::nonblocking::OpenJtalk,
infer::runtimes::onnxruntime::nonblocking::Onnxruntime, synthesizer::nonblocking::Synthesizer,
user_dict::dict::nonblocking::UserDict, voice_model::nonblocking::VoiceModel,
user_dict::dict::nonblocking::UserDict, voice_model::nonblocking::VoiceModelFile,
};

pub mod onnxruntime {
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ mod tests {
talk: enum_map!(_ => InferenceSessionOptions::new(0, DeviceSpec::Cpu)),
},
);
let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
let model_contents = &model.read_inference_models().await.unwrap();
let result = status.insert_model(model.header(), model_contents);
assert_debug_fmt_eq!(Ok(()), result);
Expand All @@ -424,7 +424,7 @@ mod tests {
talk: enum_map!(_ => InferenceSessionOptions::new(0, DeviceSpec::Cpu)),
},
);
let vvm = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let vvm = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
let model_header = vvm.header();
let model_contents = &vvm.read_inference_models().await.unwrap();
assert!(
Expand Down
31 changes: 17 additions & 14 deletions crates/voicevox_core/src/synthesizer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// TODO: `VoiceModel`のように、次のような設計にする。
// TODO: `VoiceModelFile`のように、次のような設計にする。
//
// ```
// pub(crate) mod blocking {
Expand Down Expand Up @@ -235,7 +235,7 @@ pub(crate) mod blocking {
}

/// 音声モデルを読み込む。
pub fn load_voice_model(&self, model: &crate::blocking::VoiceModel) -> Result<()> {
pub fn load_voice_model(&self, model: &crate::blocking::VoiceModelFile) -> Result<()> {
let model_bytes = &model.read_inference_models()?;
self.status.insert_model(model.header(), model_bytes)
}
Expand Down Expand Up @@ -1181,7 +1181,10 @@ pub(crate) mod nonblocking {
self.0.is_gpu_mode()
}

pub async fn load_voice_model(&self, model: &crate::nonblocking::VoiceModel) -> Result<()> {
pub async fn load_voice_model(
&self,
model: &crate::nonblocking::VoiceModelFile,
) -> Result<()> {
let model_bytes = &model.read_inference_models().await?;
self.0.status.insert_model(model.header(), model_bytes)
}
Expand Down Expand Up @@ -1351,7 +1354,7 @@ mod tests {
.unwrap();

let result = syntesizer
.load_voice_model(&crate::nonblocking::VoiceModel::sample().await.unwrap())
.load_voice_model(&crate::nonblocking::VoiceModelFile::sample().await.unwrap())
.await;

assert_debug_fmt_eq!(
Expand Down Expand Up @@ -1399,7 +1402,7 @@ mod tests {
"expected is_model_loaded to return false, but got true",
);
syntesizer
.load_voice_model(&crate::nonblocking::VoiceModel::sample().await.unwrap())
.load_voice_model(&crate::nonblocking::VoiceModelFile::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1427,7 +1430,7 @@ mod tests {
.unwrap();

syntesizer
.load_voice_model(&crate::nonblocking::VoiceModel::sample().await.unwrap())
.load_voice_model(&crate::nonblocking::VoiceModelFile::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1460,7 +1463,7 @@ mod tests {
)
.unwrap();
syntesizer
.load_voice_model(&crate::nonblocking::VoiceModel::sample().await.unwrap())
.load_voice_model(&crate::nonblocking::VoiceModelFile::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1502,7 +1505,7 @@ mod tests {
)
.unwrap();
syntesizer
.load_voice_model(&crate::nonblocking::VoiceModel::sample().await.unwrap())
.load_voice_model(&crate::nonblocking::VoiceModelFile::sample().await.unwrap())
.await
.unwrap();

Expand Down Expand Up @@ -1599,7 +1602,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let query = match input {
Expand Down Expand Up @@ -1670,7 +1673,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let accent_phrases = match input {
Expand Down Expand Up @@ -1738,7 +1741,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let accent_phrases = syntesizer
Expand Down Expand Up @@ -1801,7 +1804,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let accent_phrases = syntesizer
Expand Down Expand Up @@ -1842,7 +1845,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let accent_phrases = syntesizer
Expand Down Expand Up @@ -1883,7 +1886,7 @@ mod tests {
)
.unwrap();

let model = &crate::nonblocking::VoiceModel::sample().await.unwrap();
let model = &crate::nonblocking::VoiceModelFile::sample().await.unwrap();
syntesizer.load_voice_model(model).await.unwrap();

let accent_phrases = syntesizer
Expand Down
Loading

0 comments on commit f32872e

Please sign in to comment.