Skip to content

Commit

Permalink
[Rust] voicevox_tts_from_kana の実装 (VOICEVOX#193)
Browse files Browse the repository at this point in the history
* implements voicevox_tts_from_kana

* allow dead_code をもう少し整理

* wavの書き込み処理を共通化

* FullContextLabelErrorとKanaParseErrorについて、fromでErrorに変換できるようにする
  • Loading branch information
PickledChair authored Jul 20, 2022
1 parent 13e071e commit a6272c0
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 52 deletions.
34 changes: 24 additions & 10 deletions crates/voicevox_core/src/c_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum VoicevoxResultCode {
VOICEVOX_RESULT_INFERENCE_FAILED = 9,
VOICEVOX_RESULT_FAILED_EXTRACT_FULL_CONTEXT_LABEL = 10,
VOICEVOX_RESULT_INVALID_UTF8_INPUT = 11,
VOICEVOX_RESULT_FAILED_PARSE_KANA = 12,
}

fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
Expand Down Expand Up @@ -80,6 +81,9 @@ fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
None,
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_EXTRACT_FULL_CONTEXT_LABEL,
),
Error::FailedParseKana(_) => {
(None, VoicevoxResultCode::VOICEVOX_RESULT_FAILED_PARSE_KANA)
}
}
}
}
Expand Down Expand Up @@ -250,6 +254,13 @@ pub extern "C" fn voicevox_load_openjtalk_dict(dict_path: *const c_char) -> Voic
result_code
}

unsafe fn write_wav_to_ptr(output_wav_ptr: *mut *mut u8, output_size_ptr: *mut c_int, data: &[u8]) {
output_size_ptr.write(data.len() as c_int);
let wav_heap = libc::malloc(data.len());
libc::memcpy(wav_heap, data.as_ptr() as *const c_void, data.len());
output_wav_ptr.write(wav_heap as *mut u8);
}

#[no_mangle]
pub extern "C" fn voicevox_tts(
text: *const c_char,
Expand All @@ -266,10 +277,7 @@ pub extern "C" fn voicevox_tts(
};
if let Some(output) = output_opt {
unsafe {
output_binary_size.write(output.len() as c_int);
let wav_heap = libc::malloc(output.len());
libc::memcpy(wav_heap, output.as_ptr() as *const c_void, output.len());
output_wav.write(wav_heap as *mut u8);
write_wav_to_ptr(output_wav, output_binary_size, output.as_slice());
}
}
result_code
Expand All @@ -282,12 +290,18 @@ pub extern "C" fn voicevox_tts_from_kana(
output_binary_size: *mut c_int,
output_wav: *mut *mut u8,
) -> VoicevoxResultCode {
let (_, result_code) = convert_result(lock_internal().voicevox_tts_from_kana(
unsafe { CStr::from_ptr(text) },
speaker_id,
output_binary_size,
output_wav,
));
let (output_opt, result_code) = {
if let Ok(text) = unsafe { CStr::from_ptr(text) }.to_str() {
convert_result(lock_internal().voicevox_tts_from_kana(text, speaker_id as usize))
} else {
(None, VoicevoxResultCode::VOICEVOX_RESULT_INVALID_UTF8_INPUT)
}
};
if let Some(output) = output_opt {
unsafe {
write_wav_to_ptr(output_wav, output_binary_size, output.as_slice());
}
}
result_code
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ pub struct OjtPhoneme {
}

impl OjtPhoneme {
#[allow(dead_code)]
pub fn num_phoneme() -> usize {
PHONEME_MAP.len()
}
Expand Down
18 changes: 9 additions & 9 deletions crates/voicevox_core/src/engine/full_context_label.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ pub enum FullContextLabelError {

type Result<T> = std::result::Result<T, FullContextLabelError>;

#[allow(dead_code)]
#[derive(new, Getters, Clone, PartialEq, Debug)]
pub struct Phoneme {
contexts: HashMap<String, String>,
Expand Down Expand Up @@ -49,7 +48,6 @@ fn string_feature_by_regex(re: &Regex, label: &str) -> Result<String> {
}
}

#[allow(dead_code)]
impl Phoneme {
pub fn from_label(label: impl Into<String>) -> Result<Self> {
let mut contexts = HashMap::<String, String>::with_capacity(10);
Expand Down Expand Up @@ -77,14 +75,12 @@ impl Phoneme {
}
}

#[allow(dead_code)]
#[derive(new, Getters, Clone, PartialEq, Debug)]
pub struct Mora {
consonant: Option<Phoneme>,
vowel: Phoneme,
}

#[allow(dead_code)]
impl Mora {
pub fn set_context(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
Expand All @@ -106,6 +102,7 @@ impl Mora {
}
}

#[allow(dead_code)]
pub fn labels(&self) -> Vec<String> {
self.phonemes().iter().map(|p| p.label().clone()).collect()
}
Expand All @@ -118,7 +115,6 @@ pub struct AccentPhrase {
is_interrogative: bool,
}

#[allow(dead_code)]
impl AccentPhrase {
pub fn from_phonemes(mut phonemes: Vec<Phoneme>) -> Result<Self> {
let mut moras = Vec::with_capacity(phonemes.len());
Expand Down Expand Up @@ -175,6 +171,7 @@ impl AccentPhrase {
Ok(Self::new(moras, accent, is_interrogative))
}

#[allow(dead_code)]
pub fn set_context(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
let value = value.into();
Expand All @@ -187,10 +184,12 @@ impl AccentPhrase {
self.moras.iter().flat_map(|m| m.phonemes()).collect()
}

#[allow(dead_code)]
pub fn labels(&self) -> Vec<String> {
self.phonemes().iter().map(|p| p.label().clone()).collect()
}

#[allow(dead_code)]
pub fn merge(&self, accent_phrase: AccentPhrase) -> AccentPhrase {
let mut moras = self.moras().clone();
let is_interrogative = *accent_phrase.is_interrogative();
Expand All @@ -199,13 +198,11 @@ impl AccentPhrase {
}
}

#[allow(dead_code)]
#[derive(new, Getters, Clone, PartialEq, Debug)]
pub struct BreathGroup {
accent_phrases: Vec<AccentPhrase>,
}

#[allow(dead_code)]
impl BreathGroup {
pub fn from_phonemes(phonemes: Vec<Phoneme>) -> Result<Self> {
let mut accent_phrases = Vec::with_capacity(phonemes.len());
Expand All @@ -226,6 +223,7 @@ impl BreathGroup {
Ok(Self::new(accent_phrases))
}

#[allow(dead_code)]
pub fn set_context(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
let value = value.into();
Expand All @@ -241,19 +239,18 @@ impl BreathGroup {
.collect()
}

#[allow(dead_code)]
pub fn labels(&self) -> Vec<String> {
self.phonemes().iter().map(|p| p.label().clone()).collect()
}
}

#[allow(dead_code)]
#[derive(new, Getters, Clone, PartialEq, Debug)]
pub struct Utterance {
breath_groups: Vec<BreathGroup>,
pauses: Vec<Phoneme>,
}

#[allow(dead_code)]
impl Utterance {
pub fn from_phonemes(phonemes: Vec<Phoneme>) -> Result<Self> {
let mut breath_groups = vec![];
Expand All @@ -274,6 +271,7 @@ impl Utterance {
Ok(Self::new(breath_groups, pauses))
}

#[allow(dead_code)]
pub fn set_context(&mut self, key: impl Into<String>, value: impl Into<String>) {
let key = key.into();
let value = value.into();
Expand All @@ -282,6 +280,7 @@ impl Utterance {
}
}

#[allow(dead_code)]
pub fn phonemes(&self) -> Vec<Phoneme> {
// TODO:実装が中途半端なのであとでちゃんと実装する必要があるらしい
// https://github.com/VOICEVOX/voicevox_core/pull/174#discussion_r919982651
Expand All @@ -297,6 +296,7 @@ impl Utterance {
phonemes
}

#[allow(dead_code)]
pub fn labels(&self) -> Vec<String> {
self.phonemes().iter().map(|p| p.label().clone()).collect()
}
Expand Down
7 changes: 3 additions & 4 deletions crates/voicevox_core/src/engine/kana_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const PAUSE_DELIMITER: char = '、';
const WIDE_INTERROGATION_MARK: char = '?';
const LOOP_LIMIT: usize = 300;

#[derive(Clone, Debug)]
struct KanaParseError(String);
#[derive(Clone, Debug, PartialEq)]
pub struct KanaParseError(String);

impl std::fmt::Display for KanaParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -126,8 +126,7 @@ fn text_to_accent_phrase(phrase: &str) -> KanaParseResult<AccentPhraseModel> {
))
}

#[allow(dead_code)] // TODO: remove this feature
fn parse_kana(text: &str) -> KanaParseResult<Vec<AccentPhraseModel>> {
pub fn parse_kana(text: &str) -> KanaParseResult<Vec<AccentPhraseModel>> {
const TERMINATOR: char = '\0';
let mut parsed_result = Vec::new();
let chars_of_text = text.chars().chain([TERMINATOR]);
Expand Down
1 change: 1 addition & 0 deletions crates/voicevox_core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ use super::*;

pub use acoustic_feature_extractor::*;
pub use full_context_label::*;
pub use kana_parser::*;
pub use model::*;
pub use synthesis_engine::*;
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/engine/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl AccentPhraseModel {
}
}

#[allow(dead_code, clippy::too_many_arguments)] // TODO: remove allow(dead_code)
#[allow(clippy::too_many_arguments)]
#[derive(new, Getters)]
pub struct AudioQueryModel {
accent_phrases: Vec<AccentPhraseModel>,
Expand All @@ -41,5 +41,6 @@ pub struct AudioQueryModel {
post_phoneme_length: f32,
output_sampling_rate: u32,
output_stereo: bool,
#[allow(dead_code)]
kana: String,
}
8 changes: 0 additions & 8 deletions crates/voicevox_core/src/engine/open_jtalk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ use std::path::{Path, PathBuf};

use ::open_jtalk::*;

/*
* TODO: OpenJtalk機能を使用するようになったら、allow(dead_code)を消す
*/

#[allow(dead_code)]
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum OpenJtalkError {
#[error("open_jtalk load error")]
Expand All @@ -19,10 +14,8 @@ pub enum OpenJtalkError {
},
}

#[allow(dead_code)]
pub type Result<T> = std::result::Result<T, OpenJtalkError>;

#[allow(dead_code)]
pub struct OpenJtalk {
mecab: ManagedResource<Mecab>,
njd: ManagedResource<Njd>,
Expand All @@ -40,7 +33,6 @@ impl OpenJtalk {
}
}

#[allow(dead_code)]
pub fn extract_fullcontext(&mut self, text: impl AsRef<str>) -> Result<Vec<String>> {
let mecab_text =
text2mecab(text.as_ref()).map_err(|e| OpenJtalkError::ExtractFullContext {
Expand Down
5 changes: 1 addition & 4 deletions crates/voicevox_core/src/engine/synthesis_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ unsafe impl Send for SynthesisEngine {}
#[allow(unsafe_code)]
unsafe impl Sync for SynthesisEngine {}

#[allow(dead_code)]
#[allow(unused_variables)]
impl SynthesisEngine {
pub const DEFAULT_SAMPLING_RATE: u32 = 24000;

Expand Down Expand Up @@ -51,8 +49,7 @@ impl SynthesisEngine {
return Ok(Vec::new());
}

let utterance = Utterance::extract_full_context_label(&mut self.open_jtalk, text.as_ref())
.map_err(Error::FailedExtractFullContextLabel)?;
let utterance = Utterance::extract_full_context_label(&mut self.open_jtalk, text.as_ref())?;

let accent_phrases: Vec<AccentPhraseModel> = utterance
.breath_groups()
Expand Down
9 changes: 5 additions & 4 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Display;

use crate::engine::FullContextLabelError;
use crate::engine::{FullContextLabelError, KanaParseError};

use super::*;
use c_export::VoicevoxResultCode::{self, *};
Expand All @@ -18,8 +18,6 @@ pub enum Error {
* エラーメッセージのベースとなる文字列は必ずbase_error_message関数を使用してVoicevoxResultCodeのエラー出力の内容と対応するようにすること
*/
#[error("{}", base_error_message(VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT))]
// TODO:仮実装がlinterエラーにならないようにするための属性なのでこのenumが正式に使われる際にallow(dead_code)を取り除くこと
#[allow(dead_code)]
NotLoadedOpenjtalkDict,

#[error("{}", base_error_message(VOICEVOX_RESULT_CANT_GPU_SUPPORT))]
Expand Down Expand Up @@ -53,7 +51,10 @@ pub enum Error {
"{},{0}",
base_error_message(VOICEVOX_RESULT_FAILED_EXTRACT_FULL_CONTEXT_LABEL)
)]
FailedExtractFullContextLabel(#[source] FullContextLabelError),
FailedExtractFullContextLabel(#[from] FullContextLabelError),

#[error("{},{0}", base_error_message(VOICEVOX_RESULT_FAILED_PARSE_KANA))]
FailedParseKana(#[from] KanaParseError),
}

fn base_error_message(result_code: VoicevoxResultCode) -> &'static str {
Expand Down
35 changes: 24 additions & 11 deletions crates/voicevox_core/src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use onnxruntime::{
};
use std::collections::BTreeMap;
use std::ffi::CStr;
use std::os::raw::c_int;
use std::sync::Mutex;

use status::*;
Expand Down Expand Up @@ -149,16 +148,27 @@ impl Internal {
.synthesis_wave_format(&audio_query, speaker_id, true) // TODO: 疑問文化を設定可能にする
}

//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと
#[allow(unused_variables)]
pub fn voicevox_tts_from_kana(
&self,
text: &CStr,
speaker_id: i64,
output_binary_size: *mut c_int,
output_wav: *const *mut u8,
) -> Result<()> {
unimplemented!()
pub fn voicevox_tts_from_kana(&mut self, text: &str, speaker_id: usize) -> Result<Vec<u8>> {
let accent_phrases = parse_kana(text)?;
let accent_phrases = self
.synthesis_engine
.replace_mora_data(&accent_phrases, speaker_id)?;

let audio_query = AudioQueryModel::new(
accent_phrases,
1.,
0.,
1.,
1.,
0.1,
0.1,
SynthesisEngine::DEFAULT_SAMPLING_RATE,
false,
"".into(),
);

self.synthesis_engine
.synthesis_wave_format(&audio_query, speaker_id, true) // TODO: 疑問文化を設定可能にする
}
}

Expand Down Expand Up @@ -509,6 +519,9 @@ pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -
"入力テキストからのフルコンテキストラベル抽出に失敗しました\0"
}
VOICEVOX_RESULT_INVALID_UTF8_INPUT => "入力テキストが無効なUTF-8データでした\0",
VOICEVOX_RESULT_FAILED_PARSE_KANA => {
"入力テキストをAquesTalkライクな読み仮名としてパースすることに失敗しました\0"
}
}
}

Expand Down

0 comments on commit a6272c0

Please sign in to comment.