Skip to content

Commit

Permalink
refactor: VVMマニフェストで#[serde(flatten)]を活用
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Dec 11, 2024
1 parent 0edf940 commit 3168837
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 138 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ derive_more.workspace = true
duplicate.workspace = true
easy-ext.workspace = true
educe.workspace = true
enum-map.workspace = true
enum-map = { workspace = true, features = ["serde"] }
fs-err.workspace = true
futures-io.workspace = true
futures-lite.workspace = true
Expand Down
46 changes: 33 additions & 13 deletions crates/voicevox_core/src/manifest.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::{
collections::BTreeMap,
fmt::{self, Display},
ops::Index,
sync::Arc,
};

use derive_getters::Getters;
use derive_more::Deref;
use derive_new::new;
use macros::IndexForFields;
use enum_map::{Enum, EnumMap};
use serde::{de, Deserialize, Deserializer, Serialize};
use serde_with::{serde_as, DisplayFromStr};

Expand Down Expand Up @@ -81,24 +82,43 @@ pub struct Manifest {

pub(crate) type ManifestDomains = inference_domain_map_values!(for<D> Option<D::Manifest>);

#[derive(Deserialize, IndexForFields)]
#[derive(Deserialize)]
#[cfg_attr(test, derive(Default))]
#[index_for_fields(TalkOperation)]
pub(crate) struct TalkManifest {
#[index_for_fields(TalkOperation::PredictDuration)]
pub(crate) predict_duration_filename: Arc<str>,
#[serde(flatten)]
filenames: EnumMap<TalkOperationFilenameKey, Arc<str>>,

#[index_for_fields(TalkOperation::PredictIntonation)]
pub(crate) predict_intonation_filename: Arc<str>,
#[serde(default)]
pub(crate) style_id_to_inner_voice_id: StyleIdToInnerVoiceId,
}

#[index_for_fields(TalkOperation::GenerateFullIntermediate)]
pub(crate) generate_full_intermediate_filename: Arc<str>,
// TODO: #825 では`TalkOperation`と統合する。`Index`の実装もderive_moreで委譲する
#[derive(Enum, Deserialize)]
pub(crate) enum TalkOperationFilenameKey {
#[serde(rename = "predict_duration_filename")]
PredictDuration,
#[serde(rename = "predict_intonation_filename")]
PredictIntonation,
#[serde(rename = "generate_full_intermediate_filename")]
GenerateFullIntermediate,
#[serde(rename = "render_audio_segment_filename")]
RenderAudioSegment,
}

#[index_for_fields(TalkOperation::RenderAudioSegment)]
pub(crate) render_audio_segment_filename: Arc<str>,
impl Index<TalkOperation> for TalkManifest {
type Output = Arc<str>;

#[serde(default)]
pub(crate) style_id_to_inner_voice_id: StyleIdToInnerVoiceId,
fn index(&self, index: TalkOperation) -> &Self::Output {
let key = match index {
TalkOperation::PredictDuration => TalkOperationFilenameKey::PredictDuration,
TalkOperation::PredictIntonation => TalkOperationFilenameKey::PredictIntonation,
TalkOperation::GenerateFullIntermediate => {
TalkOperationFilenameKey::GenerateFullIntermediate
}
TalkOperation::RenderAudioSegment => TalkOperationFilenameKey::RenderAudioSegment,
};
&self.filenames[key]
}
}

#[serde_as]
Expand Down
40 changes: 21 additions & 19 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
use anyhow::{anyhow, Context as _};
use derive_more::From;
use easy_ext::ext;
use enum_map::{enum_map, EnumMap};
use enum_map::EnumMap;
use futures_io::{AsyncBufRead, AsyncRead, AsyncSeek};
use futures_util::future::{OptionFuture, TryFutureExt as _};
use itertools::Itertools as _;
Expand All @@ -23,7 +23,7 @@ use crate::{
asyncs::{Async, Mutex as _},
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation},
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain},
InferenceDomain,
},
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId},
Expand Down Expand Up @@ -128,7 +128,7 @@ impl<A: Async> Inner<A> {

let header = VoiceModelHeader::new(manifest, metas, path)?.into();

InnerTryBuilder {
return InnerTryBuilder {
header,
inference_model_entries_builder: |header| {
let VoiceModelHeader { manifest, .. } = &**header;
Expand All @@ -139,21 +139,10 @@ impl<A: Async> Inner<A> {
talk: |talk| {
talk.as_ref()
.map(|manifest| {
let indices = enum_map! {
TalkOperation::PredictDuration => {
find_entry_index(&manifest.predict_duration_filename)?
}
TalkOperation::PredictIntonation => {
find_entry_index(&manifest.predict_intonation_filename)?
}
TalkOperation::GenerateFullIntermediate => {
find_entry_index(&manifest.generate_full_intermediate_filename)?
}
TalkOperation::RenderAudioSegment => {
find_entry_index(&manifest.render_audio_segment_filename)?
}
};

let indices = EnumMap::from_fn(|k| &manifest[k])
.into_array()
.try_map_(|s| find_entry_index(s))?;
let indices = EnumMap::from_array(indices);
Ok(InferenceModelEntry { indices, manifest })
})
.transpose()
Expand All @@ -172,7 +161,20 @@ impl<A: Async> Inner<A> {
},
zip: zip.into_inner().into_inner().into(),
}
.try_build()
.try_build();

#[ext(ArrayExt)]
impl<T, const N: usize> [T; N] {
fn try_map_<O, E>(self, f: impl FnMut(T) -> Result<O, E>) -> Result<[O; N], E> {
self.into_iter()
.map(f)
.collect::<Result<Vec<_>, _>>()
.map(|vec| {
vec.try_into()
.unwrap_or_else(|_| unreachable!("should be same length"))
})
}
}
}

fn id(&self) -> VoiceModelId {
Expand Down
33 changes: 0 additions & 33 deletions crates/voicevox_core_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod extract;
mod inference_domain;
mod inference_domains;
mod manifest;

use syn::parse_macro_input;

Expand Down Expand Up @@ -103,38 +102,6 @@ pub fn derive_inference_output_signature(
from_syn(inference_domain::derive_inference_output_signature(input))
}

/// 構造体のフィールドを取得できる`std::ops::Index`の実装を導出する。
///
/// # Example
///
/// ```
/// use macros::IndexForFields;
///
/// #[derive(IndexForFields)]
/// #[index_for_fields(TalkOperation)]
/// pub(crate) struct TalkManifest {
/// #[index_for_fields(TalkOperation::PredictDuration)]
/// pub(crate) predict_duration_filename: Arc<str>,
///
/// #[index_for_fields(TalkOperation::PredictIntonation)]
/// pub(crate) predict_intonation_filename: Arc<str>,
///
/// #[index_for_fields(TalkOperation::GenerateFullIntermediate)]
/// pub(crate) generate_full_intermediate_filename: Arc<str>,
///
/// #[index_for_fields(TalkOperation::RenderAudioSegment)]
/// pub(crate) render_audio_segment_filename: Arc<str>,
///
/// // …
/// }
/// ```
#[cfg(not(doctest))]
#[proc_macro_derive(IndexForFields, attributes(index_for_fields))]
pub fn derive_index_for_fields(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = &parse_macro_input!(input);
from_syn(manifest::derive_index_for_fields(input))
}

/// # Example
///
/// ```
Expand Down
72 changes: 0 additions & 72 deletions crates/voicevox_core_macros/src/manifest.rs

This file was deleted.

0 comments on commit 3168837

Please sign in to comment.