diff --git a/config.example.toml b/config.example.toml index e29d1068..4243aa55 100644 --- a/config.example.toml +++ b/config.example.toml @@ -46,6 +46,7 @@ redis_url = "redis://localhost:6379" # Target services: you usually only have one, and # it's usually on localhost port 7269 stt_services = [ + "localhost:7269", # or: ["127.0.0.1", 7269] ] diff --git a/scripty_config/src/cfg.rs b/scripty_config/src/cfg.rs index 25717f77..b4b6da9e 100644 --- a/scripty_config/src/cfg.rs +++ b/scripty_config/src/cfg.rs @@ -34,7 +34,7 @@ pub struct BotConfig { pub error_webhook: String, /// List of \["host", port] for the STT services. - pub stt_services: Vec<(String, u16)>, + pub stt_services: Vec, /// Loki config pub loki: LokiConfig, @@ -70,6 +70,13 @@ pub struct DmSupport { pub guild_id: u64, } +#[derive(Clone, Serialize, Deserialize, Debug)] +#[serde[untagged]] +pub enum SttServiceDefinition { + IPTuple(String, u16), + HostString(String), +} + #[derive(Serialize, Deserialize, Debug)] pub struct LokiConfig { /// Loki ingest URL @@ -88,3 +95,42 @@ pub enum BotListsConfig { TokenOnly(String), FullConfig { token: String, webhook: String }, } + +#[cfg(test)] +mod tests { + use std::{ + matches, + net::{IpAddr, Ipv4Addr, SocketAddr}, + }; + + use crate::*; + + #[test] + fn test_stt_service_definition() { + #[derive(Deserialize)] + struct BotConfigTest { + svc: Vec, + } + + let parsed_cfg: BotConfigTest = + toml::from_str("svc = [\"localhost:1234\", [\"192.168.0.1\", 1234]]").unwrap(); + assert!(matches!( + parsed_cfg.svc[0], + SttServiceDefinition::HostString(_) + )); + assert!(matches!( + parsed_cfg.svc[1], + SttServiceDefinition::IPTuple(_, 1234) + )); + + match parsed_cfg.svc[1].clone() { + SttServiceDefinition::IPTuple(addr, port) => { + assert_eq!( + SocketAddr::new(addr.parse().unwrap(), port), + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 1234) + ) + } + SttServiceDefinition::HostString(_) => panic!(), + }; + } +} diff --git a/scripty_stt/src/load_balancer.rs b/scripty_stt/src/load_balancer.rs index d7c599d5..d92a4093 100644 --- a/scripty_stt/src/load_balancer.rs +++ b/scripty_stt/src/load_balancer.rs @@ -9,9 +9,11 @@ use std::{ use dashmap::DashMap; use once_cell::sync::OnceCell; +use scripty_config::SttServiceDefinition; use tokio::{ io, io::{AsyncReadExt, AsyncWriteExt}, + net::lookup_host, }; use crate::{ModelError, Stream}; @@ -31,14 +33,25 @@ pub struct LoadBalancer { impl LoadBalancer { pub async fn new() -> io::Result { - let peer_addresses = scripty_config::get_config() - .stt_services - .iter() - .map(|(addr, port)| SocketAddr::new(addr.parse().unwrap(), *port)) - .enumerate(); + let stt_services = scripty_config::get_config().stt_services.clone(); + let mut peer_addresses: Vec = Vec::new(); + for service in stt_services { + match service { + SttServiceDefinition::HostString(host) => peer_addresses.extend( + lookup_host(host) + .await + .expect("Could not resolve stt hostname"), + ), + SttServiceDefinition::IPTuple(addr, port) => peer_addresses.push(SocketAddr::new( + addr.parse() + .expect("Could not parse IP address for stt server"), + port, + )), + } + } let workers = DashMap::new(); - for (n, addr) in peer_addresses { + for (n, addr) in peer_addresses.into_iter().enumerate() { workers.insert(n, LoadBalancedStream::new(addr).await?); } Ok(Self {