Skip to content

Commit

Permalink
Merge pull request #17 from wilt00/resolve-stt-hostnames
Browse files Browse the repository at this point in the history
Resolve hostnames for STT server using DNS
  • Loading branch information
tazz4843 authored Dec 2, 2023
2 parents 46ce53c + 14c4d08 commit 5486b7d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
1 change: 1 addition & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ redis_url = "redis://localhost:6379"
# Target services: you usually only have one, and
# it's usually on localhost port 7269
stt_services = [
"localhost:7269", # or:
["127.0.0.1", 7269]
]

Expand Down
48 changes: 47 additions & 1 deletion scripty_config/src/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct BotConfig {
pub error_webhook: String,

/// List of \["host", port] for the STT services.
pub stt_services: Vec<(String, u16)>,
pub stt_services: Vec<SttServiceDefinition>,

/// Loki config
pub loki: LokiConfig,
Expand Down Expand Up @@ -70,6 +70,13 @@ pub struct DmSupport {
pub guild_id: u64,
}

#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde[untagged]]
pub enum SttServiceDefinition {
IPTuple(String, u16),
HostString(String),
}

#[derive(Serialize, Deserialize, Debug)]
pub struct LokiConfig {
/// Loki ingest URL
Expand All @@ -88,3 +95,42 @@ pub enum BotListsConfig {
TokenOnly(String),
FullConfig { token: String, webhook: String },
}

#[cfg(test)]
mod tests {
use std::{
matches,
net::{IpAddr, Ipv4Addr, SocketAddr},
};

use crate::*;

#[test]
fn test_stt_service_definition() {
#[derive(Deserialize)]
struct BotConfigTest {
svc: Vec<SttServiceDefinition>,
}

let parsed_cfg: BotConfigTest =
toml::from_str("svc = [\"localhost:1234\", [\"192.168.0.1\", 1234]]").unwrap();
assert!(matches!(
parsed_cfg.svc[0],
SttServiceDefinition::HostString(_)
));
assert!(matches!(
parsed_cfg.svc[1],
SttServiceDefinition::IPTuple(_, 1234)
));

match parsed_cfg.svc[1].clone() {
SttServiceDefinition::IPTuple(addr, port) => {
assert_eq!(
SocketAddr::new(addr.parse().unwrap(), port),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 1234)
)
}
SttServiceDefinition::HostString(_) => panic!(),
};
}
}
25 changes: 19 additions & 6 deletions scripty_stt/src/load_balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ use std::{

use dashmap::DashMap;
use once_cell::sync::OnceCell;
use scripty_config::SttServiceDefinition;
use tokio::{
io,
io::{AsyncReadExt, AsyncWriteExt},
net::lookup_host,
};

use crate::{ModelError, Stream};
Expand All @@ -31,14 +33,25 @@ pub struct LoadBalancer {

impl LoadBalancer {
pub async fn new() -> io::Result<Self> {
let peer_addresses = scripty_config::get_config()
.stt_services
.iter()
.map(|(addr, port)| SocketAddr::new(addr.parse().unwrap(), *port))
.enumerate();
let stt_services = scripty_config::get_config().stt_services.clone();
let mut peer_addresses: Vec<SocketAddr> = Vec::new();
for service in stt_services {
match service {
SttServiceDefinition::HostString(host) => peer_addresses.extend(
lookup_host(host)
.await
.expect("Could not resolve stt hostname"),
),
SttServiceDefinition::IPTuple(addr, port) => peer_addresses.push(SocketAddr::new(
addr.parse()
.expect("Could not parse IP address for stt server"),
port,
)),
}
}

let workers = DashMap::new();
for (n, addr) in peer_addresses {
for (n, addr) in peer_addresses.into_iter().enumerate() {
workers.insert(n, LoadBalancedStream::new(addr).await?);
}
Ok(Self {
Expand Down

0 comments on commit 5486b7d

Please sign in to comment.