Skip to content

Commit

Permalink
refactor: deduplicate voice listing code
Browse files Browse the repository at this point in the history
  • Loading branch information
kxxt committed May 5, 2023
1 parent 059941e commit 19f8133
Showing 1 changed file with 24 additions and 41 deletions.
65 changes: 24 additions & 41 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ use std::{borrow::Cow, path::PathBuf};

use cli::{commands::Command, Cli};

use aspeak::{AspeakError, AudioFormat, SynthesizerConfig, Voice, QUALITY_MAP};
use clap::Parser;
use color_eyre::{
eyre::{anyhow, bail},
Help, Report,
use aspeak::{
voice::{VoiceListAPIAuth, VoiceListAPIEndpoint, VoiceListAPIError, VoiceListAPIErrorKind},
AspeakError, AudioFormat, SynthesizerConfig, Voice, QUALITY_MAP,
};
use clap::Parser;
use color_eyre::{eyre::anyhow, Help, Report};
use colored::Colorize;
use constants::ORIGIN;

use env_logger::WriteStyle;
use log::debug;

use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::header::HeaderMap;
use strum::IntoEnumIterator;
use tokio_tungstenite::tungstenite::{error::ProtocolError, Error as TungsteniteError};

Expand All @@ -26,8 +25,6 @@ use crate::cli::{
config::{Config, EndpointConfig},
};

const TRIAL_VOICE_LIST_URL: Option<&str> = None;

fn main() -> color_eyre::eyre::Result<()> {
let mut cli = Cli::parse();
if cli.color == Color::Auto && std::env::var_os("NO_COLOR").is_some() {
Expand All @@ -51,7 +48,7 @@ fn main() -> color_eyre::eyre::Result<()> {
debug!("Commandline args: {cli:?}");
debug!("Profile: {config:?}");
let Cli { command, auth, .. } = cli;
let auth_options = auth.to_auth_options(config.as_ref().and_then(|c|c.auth.as_ref()))?;
let auth_options = auth.to_auth_options(config.as_ref().and_then(|c| c.auth.as_ref()))?;
debug!("Auth options: {auth_options:?}");
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
Expand All @@ -69,7 +66,6 @@ fn main() -> color_eyre::eyre::Result<()> {
.or_else(|_| Cli::process_input_text(&input_args))?;
let audio_format = output_args.get_audio_format(config.as_ref().and_then(|c|c.output.as_ref()))?;
let callback = Cli::process_output(output_args.output, output_args.overwrite)?;

let mut synthesizer = SynthesizerConfig::new(auth_options, audio_format)
.connect()
.await?;
Expand Down Expand Up @@ -133,41 +129,28 @@ fn main() -> color_eyre::eyre::Result<()> {
)
).map(|r| Cow::Owned(format!("https://{r}.tts.speech.microsoft.com/cognitiveservices/voices/list")))
})
.or_else(|| TRIAL_VOICE_LIST_URL.map(Cow::Borrowed))
// .or_else(|| TRIAL_VOICE_LIST_URL.map(Cow::Borrowed))
.ok_or_else(
|| Report::new(AspeakError::ArgumentError("No voice list API url specified!".to_string()))
.with_note(|| "The default voice list API that is used in aspeak v4 has been shutdown and is no longer available.")
.with_suggestion(|| "You can still use the list-voices command by specifying a region(authentication needed) or a custom voice list API url.")
)?;
let auth = auth.to_auth_options(config.as_ref().and_then(|c|c.auth.as_ref()))?;
let mut client = reqwest::ClientBuilder::new().no_proxy(); // Disable default system proxy detection.
if let Some(proxy) = auth.proxy() {
client = client.proxy(reqwest::Proxy::all(proxy)?);
}
let client = client.build()?;
let mut request = client.get(&*url);
if let Some(key) = auth.key() {
request = request.header(
"Ocp-Apim-Subscription-Key",
HeaderValue::from_str(key)
.map_err(|e| AspeakError::ArgumentError(e.to_string()))?,
);
}
if !auth.headers().is_empty() {
// TODO: I don't know if this could be further optimized
request = request.headers(HeaderMap::from_iter(auth.headers().iter().map(Clone::clone)));
} else if Some(url.as_ref()) == TRIAL_VOICE_LIST_URL {
// Trial endpoint
request = request.header("Origin", HeaderValue::from_str(ORIGIN).unwrap());
}
let request = request.build()?;
let response = client.execute(request).await?;
if response.status().is_client_error() {
bail!(anyhow!("Failed to retrieve voice list because of client side error.").with_note(|| "Maybe you are not authorized. Did you specify an auth token or a subscription key? Did the key/token expire?"))
} else if response.status().is_server_error() {
bail!("Failed to retrieve voice list because of server side error.")
}
let voices = response.json::<Vec<Voice>>().await?;
let auth = match (auth_options.key(), auth_options.token()) {
(_, Some(token)) => Some(VoiceListAPIAuth::AuthToken(token)),
(Some(key), None) => Some(VoiceListAPIAuth::SubscriptionKey(key)),
(None, None) => None,
};
let voices_result = Voice::request_available_voices_with_additional_headers(VoiceListAPIEndpoint::Url(url.as_ref()),
auth, auth_options.proxy(), Some(HeaderMap::from_iter(auth_options.headers().iter().map(Clone::clone)))
).await;
let voices = if let Err(VoiceListAPIError {
kind: VoiceListAPIErrorKind::ResponseError,
..
}) = voices_result {
voices_result.with_note(|| "Maybe you are not authorized. Did you specify an auth token or a subscription key? Did the key/token expire?")?
} else {
voices_result?
};
let voices = voices.iter();
let locale_id = locale.as_deref();
let voice_id = voice.as_deref();
Expand Down

0 comments on commit 19f8133

Please sign in to comment.