From 08100d0854dbc05860461c721dd85c72917d3130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Artur=20Wyszy=C5=84ski?= Date: Fri, 6 Dec 2024 20:27:01 +0100 Subject: [PATCH] refactor: implement event-driven connection statistics - Extract connection statistics into dedicated StatsManager - Add configurable stats management with StatsConfig - Implement event-based stats collection using StatEvent enum - Move connection tracking logic from ConnectionManager to StatsManager - Add comprehensive test coverage for connection lifecycle - Replace direct counter updates with async event channel - Improve error handling and logging for stats operations - Fix connection cleanup by separating stats and connection state --- src/config/mod.rs | 2 + src/config/stats.rs | 25 +++ src/connection/events.rs | 27 ++- src/connection/guard.rs | 12 +- src/connection/manager.rs | 333 +++++++++-------------------- src/connection/mod.rs | 167 +++++++++++---- src/connection/stats/client.rs | 31 ++- src/connection/stats/connection.rs | 79 ++++++- src/connection/stats/ip.rs | 14 +- src/http_api.rs | 232 ++++++++++++++++++-- src/lib.rs | 6 +- src/modbus_relay.rs | 246 +++++++++------------ src/stats_manager.rs | 248 +++++++++++++++++++++ 13 files changed, 952 insertions(+), 470 deletions(-) create mode 100644 src/config/stats.rs create mode 100644 src/stats_manager.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 6071719..8b2a1e3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,6 +4,7 @@ mod http; mod logging; mod relay; mod rtu; +mod stats; mod tcp; mod types; @@ -13,5 +14,6 @@ pub use http::Config as HttpConfig; pub use logging::Config as LoggingConfig; pub use relay::Config as RelayConfig; pub use rtu::Config as RtuConfig; +pub use stats::Config as StatsConfig; pub use tcp::Config as TcpConfig; pub use types::{DataBits, Parity, RtsType, StopBits}; diff --git a/src/config/stats.rs b/src/config/stats.rs new file mode 100644 index 0000000..c901c50 --- /dev/null +++ b/src/config/stats.rs @@ -0,0 +1,25 @@ +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + #[serde(with = "humantime_serde")] + pub cleanup_interval: Duration, + #[serde(with = "humantime_serde")] + pub idle_timeout: Duration, + #[serde(with = "humantime_serde")] + pub error_timeout: Duration, + pub max_events_per_second: u32, +} + +impl Default for Config { + fn default() -> Self { + Self { + cleanup_interval: Duration::from_secs(60), + idle_timeout: Duration::from_secs(300), + error_timeout: Duration::from_secs(300), + max_events_per_second: 10000, + } + } +} diff --git a/src/connection/events.rs b/src/connection/events.rs index 3eaf52e..c9a53be 100644 --- a/src/connection/events.rs +++ b/src/connection/events.rs @@ -1,7 +1,26 @@ -use std::time::Duration; +use super::{stats::ClientStats, ConnectionStats}; +use std::net::SocketAddr; +use tokio::sync::oneshot; -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum StatEvent { - Request { success: bool }, - ResponseTime(Duration), + /// Client connected from address + ClientConnected(SocketAddr), + /// Client disconnected from address + ClientDisconnected(SocketAddr), + /// Request processed with success/failure and duration + RequestProcessed { + addr: SocketAddr, + success: bool, + duration_ms: u64, + }, + /// Query stats for specific address + QueryStats { + addr: SocketAddr, + response_tx: oneshot::Sender, + }, + /// Query global connection stats + QueryConnectionStats { + response_tx: oneshot::Sender, + }, } diff --git a/src/connection/guard.rs b/src/connection/guard.rs index 1e4e8d0..5529cdb 100644 --- a/src/connection/guard.rs +++ b/src/connection/guard.rs @@ -1,5 +1,4 @@ use std::{net::SocketAddr, sync::Arc}; - use tokio::sync::OwnedSemaphorePermit; use tracing::debug; @@ -16,20 +15,13 @@ pub struct ConnectionGuard { impl Drop for ConnectionGuard { fn drop(&mut self) { - let manager = Arc::clone(&self.manager); + let manager = self.manager.clone(); let addr = self.addr; debug!("Closing connection from {}", addr); tokio::spawn(async move { - let mut stats = manager.stats.lock().await; - if let Some(client_stats) = stats.get_mut(&addr) { - client_stats.active_connections -= 1; - debug!( - "Connection from {} closed, active connections: {}", - addr, client_stats.active_connections - ); - } + manager.decrease_connection_count(addr).await; }); } } diff --git a/src/connection/manager.rs b/src/connection/manager.rs index 14e1657..098a739 100644 --- a/src/connection/manager.rs +++ b/src/connection/manager.rs @@ -1,55 +1,35 @@ -use std::{ - collections::{HashMap, VecDeque}, - net::SocketAddr, - sync::{ - atomic::{AtomicU32, AtomicU64, Ordering}, - Arc, - }, - time::{Duration, Instant}, -}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; -use tokio::sync::{Mutex, Semaphore}; -use tracing::info; +use tokio::sync::{mpsc, oneshot, Mutex, Semaphore}; +use tracing::{error, info}; use crate::{config::ConnectionConfig, ConnectionError, RelayError}; -use super::{ClientStats, ConnectionGuard, ConnectionStats, IpStats}; +use super::{ConnectionGuard, ConnectionStats, StatEvent}; /// TCP connection management #[derive(Debug)] pub struct Manager { /// Connection limit per IP - pub per_ip_semaphores: Arc>>>, + per_ip_semaphores: Arc>>>, /// Global connection limit - pub global_semaphore: Arc, - /// Stats per IP - pub stats: Arc>>, + global_semaphore: Arc, + /// Active connections counter per IP + active_connections: Arc>>, /// Configuration - pub config: ConnectionConfig, - /// Counter of all connections - pub total_connections: Arc, - /// Total requests - pub total_requests: AtomicU64, - /// Error count - pub error_count: AtomicU32, - /// Start time - pub start_time: Instant, - /// Response times - pub response_times: Arc>>, + config: ConnectionConfig, + /// Stats event sender + stats_tx: mpsc::Sender, } impl Manager { - pub fn new(config: ConnectionConfig) -> Self { + pub fn new(config: ConnectionConfig, stats_tx: mpsc::Sender) -> Self { Self { per_ip_semaphores: Arc::new(Mutex::new(HashMap::new())), global_semaphore: Arc::new(Semaphore::new(config.max_connections as usize)), - stats: Arc::new(Mutex::new(HashMap::new())), + active_connections: Arc::new(Mutex::new(HashMap::new())), config, - total_connections: Arc::new(AtomicU64::new(0)), - total_requests: AtomicU64::new(0), - error_count: AtomicU32::new(0), - start_time: Instant::now(), - response_times: Arc::new(Mutex::new(VecDeque::with_capacity(100))), + stats_tx, } } @@ -87,29 +67,17 @@ impl Manager { )) })?; - // Update statistics + // Increment active connections counter { - let mut stats = self.stats.lock().await; - let client_stats = stats.entry(addr).or_insert_with(|| ClientStats { - active_connections: 0, - last_active: Instant::now(), - total_requests: 0, - error_count: 0, - last_error: None, - }); - - // Check for potential overflow - if client_stats.active_connections == usize::MAX { - return Err(RelayError::Connection(ConnectionError::invalid_state( - "Active connections counter overflow".to_string(), - ))); - } - - client_stats.active_connections += 1; - client_stats.last_active = Instant::now(); + let mut active_conns = self.active_connections.lock().await; + let conn_count = active_conns.entry(addr).or_default(); + *conn_count = conn_count.saturating_add(1); } - self.total_connections.fetch_add(1, Ordering::Relaxed); + // Notify stats manager about new connection + if let Err(e) = self.stats_tx.send(StatEvent::ClientConnected(addr)).await { + error!("Failed to send connection event to stats manager: {}", e); + } Ok(ConnectionGuard { manager: Arc::clone(self), @@ -120,205 +88,118 @@ impl Manager { } pub async fn close_all_connections(&self) -> Result<(), RelayError> { - let stats = self.stats.lock().await; - let active_connections = stats.values().map(|s| s.active_connections).sum::(); + let active_conns = self.active_connections.lock().await; + let total_active: usize = active_conns.values().sum(); - if active_connections > 0 { - info!("Closing {} active connections", active_connections); - // TODO(aljen): Here we can add code to forcefully close connections - // e.g., by sending a signal to all ConnectionGuard + if total_active > 0 { + info!("Closing {} active connections", total_active); + // TODO(aljen): Logic for force-closing connections should be added here } Ok(()) } - pub async fn record_client_error(&self, addr: &SocketAddr) -> Result<(), RelayError> { - let mut stats = self.stats.lock().await; - let client_stats = stats.entry(*addr).or_insert_with(|| ClientStats { - active_connections: 0, - last_active: Instant::now(), - total_requests: 0, - error_count: 0, - last_error: None, - }); - - client_stats.error_count += 1; - client_stats.last_error = Some(Instant::now()); - - Ok(()) - } - - /// Updates statistics for a given connection - pub async fn record_request(&self, addr: SocketAddr, success: bool) { - let mut stats = self.stats.lock().await; - if let Some(client_stats) = stats.get_mut(&addr) { - client_stats.total_requests += 1; - client_stats.last_active = Instant::now(); - if !success { - client_stats.error_count += 1; - client_stats.last_error = Some(Instant::now()); - } - } + pub async fn get_connection_count(&self, addr: &SocketAddr) -> usize { + self.active_connections + .lock() + .await + .get(addr) + .copied() + .unwrap_or(0) } - fn should_cleanup_connection( - stats: &ClientStats, - now: Instant, - idle_timeout: Duration, - error_timeout: Duration, - ) -> bool { - now.duration_since(stats.last_active) >= idle_timeout - || (stats.error_count > 0 - && now.duration_since(stats.last_error.unwrap_or(now)) >= error_timeout) + pub async fn get_total_connections(&self) -> usize { + self.active_connections.lock().await.values().sum() } - /// Cleans up idle connections - pub async fn cleanup_idle_connections(&self) -> Result<(), RelayError> { - let now = Instant::now(); - - // First pass: collect connections to clean - let to_clean: Vec<(SocketAddr, ClientStats)> = { - let stats = self.stats.lock().await; - stats - .iter() - .filter(|(_, s)| { - Self::should_cleanup_connection( - s, - now, - self.config.idle_timeout, - self.config.error_timeout, - ) - }) - .map(|(addr, s)| (*addr, (*s).clone())) - .collect() - }; // stats lock is dropped here - - // Second pass: verify and cleanup - for (addr, stats_snapshot) in to_clean { - let mut stats = self.stats.lock().await; - // Recheck conditions before cleanup - if Self::should_cleanup_connection( - &stats_snapshot, - now, - self.config.idle_timeout, - self.config.error_timeout, - ) { - stats.remove(&addr); - info!( - "Cleaned up connection {} ({} connections, {} errors, last active: {:?} ago)", - addr, - stats_snapshot.active_connections, - stats_snapshot.error_count, - now.duration_since(stats_snapshot.last_active) - ); - } + /// Updates statistics for a given request + pub async fn record_request(&self, addr: SocketAddr, success: bool, duration: Duration) { + if let Err(e) = self + .stats_tx + .send(StatEvent::RequestProcessed { + addr, + success, + duration_ms: duration.as_millis() as u64, + }) + .await + { + error!("Failed to send request stats: {}", e); } - - Ok(()) } - /// Returns connection statistics + /// Gets complete connection statistics pub async fn get_stats(&self) -> Result { - let stats = self.stats.lock().await; - let mut total_active: usize = 0; - let mut total_requests: u64 = 0; - let mut total_errors: u64 = 0; - let mut per_ip_stats = HashMap::new(); - - for (addr, client_stats) in stats.iter() { - // Check for counter overflow - if total_active - .checked_add(client_stats.active_connections) - .is_none() - { - return Err(RelayError::Connection(ConnectionError::invalid_state( - "Total active connections counter overflow".to_string(), - ))); - } - total_active += client_stats.active_connections; - - if total_requests - .checked_add(client_stats.total_requests) - .is_none() - { - return Err(RelayError::Connection(ConnectionError::invalid_state( - "Total requests counter overflow".to_string(), - ))); - } - total_requests += client_stats.total_requests; - - if total_errors.checked_add(client_stats.error_count).is_none() { - return Err(RelayError::Connection(ConnectionError::invalid_state( - "Total errors counter overflow".to_string(), - ))); - } - total_errors += client_stats.error_count; + let (tx, rx) = oneshot::channel(); - per_ip_stats.insert( - *addr, - IpStats { - active_connections: client_stats.active_connections, - total_requests: client_stats.total_requests, - error_count: client_stats.error_count, - last_active: client_stats.last_active, - last_error: client_stats.last_error, - }, - ); - } + self.stats_tx + .send(StatEvent::QueryConnectionStats { response_tx: tx }) + .await + .map_err(|_| { + RelayError::Connection(ConnectionError::invalid_state( + "Failed to query connection stats", + )) + })?; - Ok(ConnectionStats { - total_connections: self.total_connections.load(Ordering::Relaxed), - active_connections: total_active, - total_requests, - total_errors, - per_ip_stats, + rx.await.map_err(|_| { + RelayError::Connection(ConnectionError::invalid_state( + "Failed to receive connection stats", + )) }) } - pub async fn connection_count(&self) -> u32 { - self.stats.lock().await.len() as u32 - } + /// Cleans up idle connections + pub(crate) async fn cleanup_idle_connections(&self) -> Result<(), RelayError> { + // Cleanup is now handled by StatsManager, we just need to sync our active connections + let (tx, rx) = oneshot::channel(); - pub fn total_requests(&self) -> u64 { - self.total_requests.load(Ordering::Relaxed) - } + self.stats_tx + .send(StatEvent::QueryConnectionStats { response_tx: tx }) + .await + .map_err(|_| { + RelayError::Connection(ConnectionError::invalid_state( + "Failed to query stats for cleanup", + )) + })?; - pub fn error_count(&self) -> u32 { - self.error_count.load(Ordering::Relaxed) - } + let stats = rx.await.map_err(|_| { + RelayError::Connection(ConnectionError::invalid_state( + "Failed to receive stats for cleanup", + )) + })?; + + let mut active_conns = self.active_connections.lock().await; + active_conns.retain(|addr, count| { + if let Some(ip_stats) = stats.per_ip_stats.get(addr) { + ip_stats.active_connections > 0 + } else { + // If no stats exist, connection is considered inactive + *count == 0 + } + }); - pub async fn avg_response_time(&self) -> Duration { - let times = self.response_times.lock().await; - if times.is_empty() { - return Duration::from_millis(0); - } - let sum: Duration = times.iter().sum(); - sum / times.len() as u32 + Ok(()) } - pub fn requests_per_second(&self) -> f64 { - let total = self.total_requests.load(Ordering::Relaxed) as f64; - let elapsed = self.start_time.elapsed().as_secs_f64(); - if elapsed > 0.0 { - total / elapsed - } else { - 0.0 + pub(crate) async fn decrease_connection_count(&self, addr: SocketAddr) { + let mut active_conns = self.active_connections.lock().await; + if let Some(count) = active_conns.get_mut(&addr) { + *count = count.saturating_sub(1); + if *count == 0 { + active_conns.remove(&addr); + } } - } - - pub fn record_requests(&self) { - self.total_requests.fetch_add(1, Ordering::Relaxed); - } - pub fn record_errors(&self) { - self.error_count.fetch_add(1, Ordering::Relaxed); + // Notify stats manager about disconnection + if let Err(e) = self + .stats_tx + .send(StatEvent::ClientDisconnected(addr)) + .await + { + error!("Failed to send disconnection event to stats manager: {}", e); + } } - pub async fn record_response_time(&self, duration: Duration) { - let mut times = self.response_times.lock().await; - if times.len() >= 100 { - times.pop_front(); - } - times.push_back(duration); + pub fn stats_tx(&self) -> mpsc::Sender { + self.stats_tx.clone() } } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 592750b..4f11886 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -1,31 +1,35 @@ mod backoff_strategy; +mod events; mod guard; mod manager; mod stats; -mod events; pub use backoff_strategy::BackoffStrategy; +pub use events::StatEvent; pub use guard::ConnectionGuard; pub use manager::Manager as ConnectionManager; pub use stats::ClientStats; pub use stats::ConnectionStats; pub use stats::IpStats; -pub use events::StatEvent; #[cfg(test)] mod tests { - use tokio::time::sleep; + use tokio::{ + sync::{broadcast, mpsc, Mutex}, + time::sleep, + }; use crate::{ config::{BackoffConfig, ConnectionConfig}, - ConnectionError, RelayError, + ConnectionError, RelayError, StatsConfig, StatsManager, }; use super::*; use std::{ + collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, - time::{Duration, Instant}, + time::Duration, }; #[tokio::test] @@ -39,7 +43,8 @@ mod tests { backoff: BackoffConfig::default(), }; - let manager = Arc::new(ConnectionManager::new(config)); + let (stats_tx, _) = mpsc::channel(100); + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); // First connection should succeed @@ -69,7 +74,22 @@ mod tests { ..Default::default() }; - let manager = Arc::new(ConnectionManager::new(config)); + let stats_config = StatsConfig::default(); + + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + let stats_handle = tokio::spawn({ + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); + + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); // First connection succeeds @@ -79,7 +99,8 @@ mod tests { let _err = manager.accept_connection(addr).await.unwrap_err(); // Check stats - let stats = manager.get_stats().await.unwrap(); // Unwrap Result + let stats = manager.get_stats().await.unwrap(); + assert_eq!( stats.active_connections, 1, "Should have one active connection" @@ -91,20 +112,9 @@ mod tests { // Cleanup drop(conn); - } - - #[tokio::test] - async fn test_error_recording() { - let manager = Arc::new(ConnectionManager::new(ConnectionConfig::default())); - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); - - // Record some errors - assert!(manager.record_client_error(&addr).await.is_ok()); - // Verify error was recorded - let stats = manager.get_stats().await.unwrap(); - assert_eq!(stats.total_errors, 1); - assert!(stats.per_ip_stats.get(&addr).unwrap().error_count == 1); + shutdown_tx.send(()).unwrap(); + stats_handle.await.unwrap(); } #[tokio::test] @@ -114,7 +124,26 @@ mod tests { ..Default::default() }; - let manager = Arc::new(ConnectionManager::new(config)); + let stats_config = StatsConfig { + cleanup_interval: config.idle_timeout, + idle_timeout: config.idle_timeout, + error_timeout: config.error_timeout, + max_events_per_second: 10000, // TODO(aljen): Make configurable + }; + + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + let stats_handle = tokio::spawn({ + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); + + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); // Create a connection @@ -133,37 +162,36 @@ mod tests { // Verify connection was cleaned up let stats = manager.get_stats().await.unwrap(); assert_eq!(stats.active_connections, 0); + + shutdown_tx.send(()).unwrap(); + stats_handle.await.unwrap(); } #[tokio::test] - async fn test_stats_counter_overflow() { - let manager = Arc::new(ConnectionManager::new(ConnectionConfig::default())); - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); + async fn test_connection_guard_cleanup() { + let config = ConnectionConfig::default(); - // Manually set counters to near max to test overflow protection - { - let mut stats = manager.stats.lock().await; - let client_stats = stats.entry(addr).or_insert_with(|| ClientStats { - active_connections: usize::MAX, - last_active: Instant::now(), - total_requests: u64::MAX, - error_count: u64::MAX, - last_error: None, - }); - client_stats.active_connections = usize::MAX; - } + let stats_config = StatsConfig { + cleanup_interval: config.idle_timeout, + idle_timeout: config.idle_timeout, + error_timeout: config.error_timeout, + max_events_per_second: 10000, // TODO(aljen): Make configurable + }; - // Attempting to increment should result in error - let result = manager.accept_connection(addr).await; - assert!(matches!( - result.unwrap_err(), - RelayError::Connection(ConnectionError::InvalidState(_)) - )); - } + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + let stats_handle = tokio::spawn({ + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); + + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); - #[tokio::test] - async fn test_connection_guard_cleanup() { - let manager = Arc::new(ConnectionManager::new(ConnectionConfig::default())); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234); { @@ -180,6 +208,9 @@ mod tests { let stats = manager.get_stats().await.unwrap(); assert_eq!(stats.active_connections, 0); + + shutdown_tx.send(()).unwrap(); + stats_handle.await.unwrap(); } #[tokio::test] @@ -205,4 +236,46 @@ mod tests { strategy.reset(); assert_eq!(strategy.next_backoff().unwrap().as_millis(), 100); } + + #[tokio::test] + async fn test_connection_lifecycle() { + let config = ConnectionConfig::default(); + let (stats_tx, mut stats_rx) = mpsc::channel(100); + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); + + // Handle stats events in background + tokio::spawn(async move { + while let Some(event) = stats_rx.recv().await { + match event { + StatEvent::QueryConnectionStats { response_tx } => { + let _ = response_tx.send(ConnectionStats { + total_connections: 1, + active_connections: 1, + total_requests: 0, + total_errors: 0, + requests_per_second: 0.0, + avg_response_time_ms: 0, + per_ip_stats: HashMap::new(), + }); + } + _ => {} + } + } + }); + + let addr = "127.0.0.1:8080".parse().unwrap(); + + // Test connection acceptance + let guard = manager.accept_connection(addr).await.unwrap(); + assert_eq!(manager.get_connection_count(&addr).await, 1); + + // Test statistics + let stats = manager.get_stats().await.unwrap(); + assert_eq!(stats.active_connections, 1); + + // Test connection cleanup + drop(guard); + sleep(Duration::from_millis(100)).await; + assert_eq!(manager.get_connection_count(&addr).await, 0); + } } diff --git a/src/connection/stats/client.rs b/src/connection/stats/client.rs index f04884d..ff40df0 100644 --- a/src/connection/stats/client.rs +++ b/src/connection/stats/client.rs @@ -1,16 +1,33 @@ -use std::time::Instant; +use std::time::SystemTime; + +use serde::Serialize; /// Stats for a single client -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct Stats { /// Number of active connections from this address pub active_connections: usize, - /// Last activity - pub last_active: Instant, /// Total number of requests pub total_requests: u64, - /// Number of errors - pub error_count: u64, + /// Total number of errors + pub total_errors: u64, + /// Last activity + pub last_active: SystemTime, /// Timestamp of the last error - pub last_error: Option, + pub last_error: Option, + /// Average response time + pub avg_response_time_ms: u64, +} + +impl Default for Stats { + fn default() -> Self { + Self { + active_connections: 0, + total_requests: 0, + total_errors: 0, + last_active: SystemTime::now(), + last_error: None, + avg_response_time_ms: 0, + } + } } diff --git a/src/connection/stats/connection.rs b/src/connection/stats/connection.rs index 3b5863e..825f3cb 100644 --- a/src/connection/stats/connection.rs +++ b/src/connection/stats/connection.rs @@ -1,12 +1,85 @@ -use std::{collections::HashMap, net::SocketAddr}; +use std::{ + collections::HashMap, + net::SocketAddr, + time::{Duration, SystemTime}, +}; -use super::IpStats; +use serde::Serialize; -#[derive(Debug)] +use super::{ClientStats, IpStats}; + +#[derive(Debug, Serialize)] pub struct Stats { pub total_connections: u64, pub active_connections: usize, pub total_requests: u64, pub total_errors: u64, + pub requests_per_second: f64, + pub avg_response_time_ms: u64, pub per_ip_stats: HashMap, } + +impl Stats { + pub fn from_client_stats(stats: &HashMap) -> Self { + let mut total_active = 0; + let mut total_requests = 0; + let mut total_errors = 0; + let mut total_response_time = 0u64; + let mut response_time_count = 0; + let mut per_ip = HashMap::new(); + + // Calculate totals and build per-IP stats + for (addr, client) in stats { + total_active += client.active_connections; + total_requests += client.total_requests; + total_errors += client.total_errors; + + if client.avg_response_time_ms > 0 { + total_response_time += client.avg_response_time_ms; + response_time_count += 1; + } + + per_ip.insert( + *addr, + IpStats { + active_connections: client.active_connections, + total_requests: client.total_requests, + total_errors: client.total_errors, + last_active: client.last_active, + last_error: client.last_error, + avg_response_time_ms: client.avg_response_time_ms, + }, + ); + } + + Self { + total_connections: total_active as u64, + active_connections: total_active, + total_requests, + total_errors, + requests_per_second: Self::calculate_requests_per_second(stats), + avg_response_time_ms: if response_time_count > 0 { + total_response_time / response_time_count + } else { + 0 + }, + per_ip_stats: per_ip, + } + } + + fn calculate_requests_per_second(stats: &HashMap) -> f64 { + let now = SystemTime::now(); + let window = Duration::from_secs(60); + let mut recent_requests = 0; + + for client in stats.values() { + if let Ok(duration) = now.duration_since(client.last_active) { + if duration <= window { + recent_requests += client.total_requests as usize; + } + } + } + + recent_requests as f64 / window.as_secs_f64() + } +} diff --git a/src/connection/stats/ip.rs b/src/connection/stats/ip.rs index 057f8e9..860bc7d 100644 --- a/src/connection/stats/ip.rs +++ b/src/connection/stats/ip.rs @@ -1,10 +1,14 @@ -use std::time::Instant; +use std::time::SystemTime; -#[derive(Debug)] +use serde::Serialize; + +/// Stats for a single IP address +#[derive(Debug, Clone, Serialize)] pub struct Stats { pub active_connections: usize, pub total_requests: u64, - pub error_count: u64, - pub last_active: Instant, - pub last_error: Option, + pub total_errors: u64, + pub last_active: SystemTime, + pub last_error: Option, + pub avg_response_time_ms: u64, } diff --git a/src/http_api.rs b/src/http_api.rs index 9752b35..811e160 100644 --- a/src/http_api.rs +++ b/src/http_api.rs @@ -1,11 +1,11 @@ -use std::sync::Arc; +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::SystemTime}; use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Json, Router}; use serde::Serialize; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, oneshot}; use tracing::info; -use crate::ConnectionManager; +use crate::{connection::StatEvent, ConnectionManager}; #[derive(Debug, Serialize)] struct HealthResponse { @@ -15,36 +15,141 @@ struct HealthResponse { } #[derive(Debug, Serialize)] -struct StatsResponse { +struct IpStatsResponse { + active_connections: usize, total_requests: u64, - active_connections: u32, - error_count: u32, + total_errors: u64, avg_response_time_ms: u64, + last_active: SystemTime, + last_error: Option, +} + +#[derive(Debug, Serialize)] +struct StatsResponse { + // Basic stats + total_connections: u64, + active_connections: u32, + total_requests: u64, + total_errors: u64, requests_per_second: f64, + avg_response_time_ms: u64, + + // Stats per IP + per_ip_stats: HashMap, } type ApiState = Arc; async fn health_handler(State(state): State) -> impl IntoResponse { - let response = HealthResponse { - status: "ok", - tcp_connections: state.connection_count().await, - rtu_status: "ok", // TODO: Implement RTU status check - }; + let (tx, rx) = oneshot::channel(); + + if (state + .stats_tx() + .send(StatEvent::QueryConnectionStats { response_tx: tx }) + .await) + .is_err() + { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(HealthResponse { + status: "error", + tcp_connections: 0, + rtu_status: "unknown", + }), + ); + } - (StatusCode::OK, Json(response)) + match rx.await { + Ok(stats) => { + ( + StatusCode::OK, + Json(HealthResponse { + status: "ok", + tcp_connections: stats.active_connections as u32, + rtu_status: "ok", // TODO(aljen): Implement RTU status check + }), + ) + } + Err(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(HealthResponse { + status: "error", + tcp_connections: 0, + rtu_status: "unknown", + }), + ), + } } async fn stats_handler(State(state): State) -> impl IntoResponse { - let response = StatsResponse { - total_requests: state.total_requests(), - active_connections: state.connection_count().await, - error_count: state.error_count(), - avg_response_time_ms: state.avg_response_time().await.as_millis() as u64, - requests_per_second: state.requests_per_second(), - }; - - (StatusCode::OK, Json(response)) + let (tx, rx) = oneshot::channel(); + + if (state + .stats_tx() + .send(StatEvent::QueryConnectionStats { response_tx: tx }) + .await) + .is_err() + { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(StatsResponse { + total_connections: 0, + active_connections: 0, + total_requests: 0, + total_errors: 0, + requests_per_second: 0.0, + avg_response_time_ms: 0, + per_ip_stats: HashMap::new(), + }), + ); + } + + match rx.await { + Ok(stats) => { + let per_ip_stats = stats + .per_ip_stats + .into_iter() + .map(|(addr, ip_stats)| { + ( + addr, + IpStatsResponse { + active_connections: ip_stats.active_connections, + total_requests: ip_stats.total_requests, + total_errors: ip_stats.total_errors, + avg_response_time_ms: ip_stats.avg_response_time_ms, + last_active: ip_stats.last_active, + last_error: ip_stats.last_error, + }, + ) + }) + .collect(); + + ( + StatusCode::OK, + Json(StatsResponse { + total_connections: stats.total_connections, + active_connections: stats.active_connections as u32, + total_requests: stats.total_requests, + total_errors: stats.total_errors, + requests_per_second: stats.requests_per_second, + avg_response_time_ms: stats.avg_response_time_ms, + per_ip_stats, + }), + ) + } + Err(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(StatsResponse { + total_connections: 0, + active_connections: 0, + total_requests: 0, + total_errors: 0, + requests_per_second: 0.0, + avg_response_time_ms: 0, + per_ip_stats: HashMap::new(), + }), + ), + } } pub async fn start_http_server( @@ -72,3 +177,88 @@ pub async fn start_http_server( Ok(()) } + +#[cfg(test)] +mod tests { + use crate::{ConnectionConfig, StatsManager}; + + use super::*; + use axum::body::Body; + use axum::http::Request; + use tokio::sync::Mutex; + use tower::ServiceExt; + + #[tokio::test] + async fn test_health_endpoint() { + // Create a test stats manager + let config = ConnectionConfig::default(); + let stats_config = crate::StatsConfig::default(); + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + let stats_handle = tokio::spawn({ + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); + + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); + + // Build test app + let app = Router::new() + .route("/health", get(health_handler)) + .with_state(manager); + + // Create test request + let req = Request::builder() + .uri("/health") + .body(Body::empty()) + .unwrap(); + + // Get response + let response = app.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + shutdown_tx.send(()).unwrap(); + stats_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_stats_endpoint() { + let config = ConnectionConfig::default(); + let stats_config = crate::StatsConfig::default(); + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + + let stats_handle = tokio::spawn({ + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); + + let manager = Arc::new(ConnectionManager::new(config, stats_tx)); + + let app = Router::new() + .route("/stats", get(stats_handler)) + .with_state(manager); + + let req = Request::builder() + .uri("/stats") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(req).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + shutdown_tx.send(()).unwrap(); + stats_handle.await.unwrap(); + } +} diff --git a/src/lib.rs b/src/lib.rs index b866bc2..8999fe9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,12 @@ pub mod http_api; pub mod modbus; pub mod modbus_relay; pub mod rtu_transport; +pub mod stats_manager; mod utils; -pub use config::{ConnectionConfig, HttpConfig, LoggingConfig, RelayConfig, RtuConfig, TcpConfig}; +pub use config::{ + ConnectionConfig, HttpConfig, LoggingConfig, RelayConfig, RtuConfig, StatsConfig, TcpConfig, +}; pub use config::{DataBits, Parity, RtsType, StopBits}; pub use connection::BackoffStrategy; pub use connection::{ClientStats, ConnectionStats, IpStats}; @@ -20,3 +23,4 @@ pub use http_api::start_http_server; pub use modbus::{guess_response_size, ModbusProcessor}; pub use modbus_relay::ModbusRelay; pub use rtu_transport::RtuTransport; +pub use stats_manager::StatsManager; diff --git a/src/modbus_relay.rs b/src/modbus_relay.rs index e5c8117..17282a6 100644 --- a/src/modbus_relay.rs +++ b/src/modbus_relay.rs @@ -7,7 +7,7 @@ use tokio::{ task::JoinHandle, time::{sleep, timeout}, }; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use crate::{ connection::StatEvent, @@ -18,34 +18,22 @@ use crate::{ http_api::start_http_server, rtu_transport::RtuTransport, utils::generate_request_id, - ConnectionManager, IoOperation, ModbusProcessor, RelayConfig, + ConnectionManager, IoOperation, ModbusProcessor, RelayConfig, StatsConfig, StatsManager, }; use socket2::{SockRef, TcpKeepalive}; -const STATS_CHANNEL_SIZE: usize = 100; - pub struct ModbusRelay { config: RelayConfig, transport: Arc, connection_manager: Arc, - stats_manager: Arc, + stats_manager: Arc>, stats_tx: mpsc::Sender, shutdown: broadcast::Sender<()>, main_shutdown: tokio::sync::watch::Sender, tasks: Arc>>>, } -fn spawn_task(name: &str, tasks: &mut Vec>, future: F) -where - F: Future + Send + 'static, -{ - let task = tokio::spawn(future); - debug!("Spawned {} task: {:?}", name, task.id()); - - tasks.push(task); -} - impl ModbusRelay { pub fn new(config: RelayConfig) -> Result { // Validate the config first @@ -53,49 +41,24 @@ impl ModbusRelay { let transport = RtuTransport::new(&config.rtu, config.logging.trace_frames)?; - // Initialize connection managers with connection config from RelayConfig - let connection_manager = Arc::new(ConnectionManager::new(config.connection.clone())); - let stats_manager = Arc::new(ConnectionManager::new(config.connection.clone())); - - // Create channel for stats events - let (stats_tx, mut stats_rx) = mpsc::channel(STATS_CHANNEL_SIZE); - - let mut tasks = Vec::new(); - - let shutdown = broadcast::channel(1).0; - - let stats_manager_clone = Arc::clone(&stats_manager); - let mut shutdown_rx = shutdown.subscribe(); - - spawn_task("stats_manager", &mut tasks, async move { - loop { - tokio::select! { - Some(event) = stats_rx.recv() => { - match event { - StatEvent::Request { success } => { - if success { - stats_manager_clone.record_requests(); - } else { - stats_manager_clone.record_errors(); - } - } - StatEvent::ResponseTime(duration) => { - stats_manager_clone.record_response_time(duration).await; - } - } - } - _ = shutdown_rx.recv() => { - info!("Stats manager shutting down"); - break; - } - else => break - } - } - }); + // Create stats manager first + let stats_config = StatsConfig { + cleanup_interval: config.connection.idle_timeout, + idle_timeout: config.connection.idle_timeout, + error_timeout: config.connection.error_timeout, + max_events_per_second: 10000, // TODO(aljen): Make configurable + }; + let (stats_manager, stats_tx) = StatsManager::new(stats_config); + let stats_manager = Arc::new(Mutex::new(stats_manager)); - let tasks = Arc::new(Mutex::new(tasks)); + // Initialize connection manager with stats sender + let connection_manager = Arc::new(ConnectionManager::new( + config.connection.clone(), + stats_tx.clone(), + )); - let (main_shutdown, _) = tokio::sync::watch::channel(false); + let (shutdown_tx, _) = broadcast::channel(1); + let (main_shutdown_tx, _) = tokio::sync::watch::channel(false); Ok(Self { config, @@ -103,9 +66,9 @@ impl ModbusRelay { connection_manager, stats_manager, stats_tx, - shutdown, - main_shutdown, - tasks, + shutdown: shutdown_tx, + main_shutdown: main_shutdown_tx, + tasks: Arc::new(Mutex::new(Vec::new())), }) } @@ -163,7 +126,17 @@ impl ModbusRelay { } pub async fn run(self: Arc) -> Result<(), RelayError> { - let mut tasks = Vec::new(); + // Start stats manager + let shutdown_rx = self.shutdown.subscribe(); + + self.spawn_task("stats_manager", { + let stats_manager = Arc::clone(&self.stats_manager); + + async move { + let mut stats_manager = stats_manager.lock().await; + stats_manager.run(shutdown_rx).await; + } + }); // Start TCP server let tcp_server = { @@ -229,20 +202,25 @@ impl ModbusRelay { Ok::<_, RelayError>(()) }) }; - tasks.push(tcp_server); + + self.spawn_task("tcp_server", async move { + if let Err(e) = tcp_server.await { + error!("TCP server task failed: {}", e); + } + }); // Start HTTP server if enabled if self.config.http.enabled { let http_server = start_http_server( self.config.http.bind_addr.clone(), self.config.http.bind_port, - self.stats_manager.clone(), + self.connection_manager.clone(), self.shutdown.subscribe(), ); self.spawn_task("http", async move { if let Err(e) = http_server.await { - error!("HTTP server error: {}", e); + error!("HTTP server error: {}", e) } }); } @@ -252,9 +230,11 @@ impl ModbusRelay { let mut shutdown_rx = self.shutdown.subscribe(); self.spawn_task("cleanup", async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { tokio::select! { - _ = sleep(Duration::from_secs(60)) => { + _ = interval.tick() => { if let Err(e) = manager.cleanup_idle_connections().await { error!("Error during connection cleanup: {}", e); } @@ -267,27 +247,7 @@ impl ModbusRelay { } }); - // Periodically log statistics - let manager = Arc::clone(&self.stats_manager); - let mut shutdown_rx = self.shutdown.subscribe(); - - self.spawn_task("stats", async move { - loop { - tokio::select! { - _ = sleep(Duration::from_secs(300)) => { - match manager.get_stats().await { - Ok(stats) => info!("Connection stats: {:?}", stats), - Err(e) => error!("Failed to get connection stats: {}", e), - } - } - _ = shutdown_rx.recv() => { - debug!("Stats task received shutdown signal"); - break; - } - } - } - }); - + // Wait for shutdown signal let mut shutdown_rx = self.main_shutdown.subscribe(); tokio::select! { @@ -296,13 +256,6 @@ impl ModbusRelay { } } - // Wait for all tasks to complete - for task in tasks { - if let Err(e) = task.await { - error!("Task error: {}", e); - } - } - info!("Main loop exited"); Ok(()) } @@ -312,20 +265,20 @@ impl ModbusRelay { info!("Initiating graceful shutdown"); let timeout_duration = Duration::from_secs(5); + // Send main shutdown signal let _ = self.main_shutdown.send(true); // 1. Log initial state - if let Ok(stats) = self.stats_manager.get_stats().await { - info!( - "Current state: {} active connections, {} total requests", - stats.active_connections, stats.total_requests - ); - } + let stats = self.connection_manager.get_stats().await?; + info!( + "Current state: {} active connections, {} total requests", + stats.active_connections, stats.total_requests + ); - // 2. Sending shutdown signal to all tasks + // 2. Send shutdown signal to all tasks info!("Sending shutdown signal to tasks"); self.shutdown.send(()).map_err(|e| { - RelayError::Connection(ConnectionError::InvalidState(format!( + RelayError::Connection(ConnectionError::invalid_state(format!( "Failed to send shutdown signal: {}", e ))) @@ -334,18 +287,13 @@ impl ModbusRelay { // 3. Initiate connection shutdown info!("Initiating connection shutdown"); if let Err(e) = self.connection_manager.close_all_connections().await { - error!("Error initiating connection shutdown: {}", e); + error!("Error during connection shutdown: {}", e); } // 4. Wait for connections to close with timeout let start = Instant::now(); - loop { - if start.elapsed() >= timeout_duration { - error!("Timeout waiting for connections to close"); - break; - } - - if let Ok(stats) = self.stats_manager.get_stats().await { + while start.elapsed() < timeout_duration { + if let Ok(stats) = self.connection_manager.get_stats().await { if stats.active_connections == 0 { info!("All connections closed"); break; @@ -355,7 +303,7 @@ impl ModbusRelay { stats.active_connections ); } - tokio::time::sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)).await; } // 5. Now we can safely close the serial port @@ -506,34 +454,15 @@ async fn send_response( } async fn handle_client( - stream: TcpStream, + mut stream: TcpStream, peer_addr: SocketAddr, transport: Arc, manager: Arc, stats_tx: mpsc::Sender, ) -> Result<(), RelayError> { - let start_time = Instant::now(); - // Create connection guard to track this connection let _guard = manager.accept_connection(peer_addr).await?; - let result = handle_client_inner(stream, peer_addr, transport, stats_tx).await; - - if result.is_err() { - manager.record_client_error(&peer_addr).await?; - } - - manager.record_response_time(start_time.elapsed()).await; - - result -} - -async fn handle_client_inner( - mut stream: TcpStream, - peer_addr: SocketAddr, - transport: Arc, - stats_tx: mpsc::Sender, -) -> Result<(), RelayError> { let request_id = generate_request_id(); let client_span = tracing::info_span!( @@ -568,15 +497,18 @@ async fn handle_client_inner( break; } Err(e) => { - // Record TCP frame error stats_tx - .send(StatEvent::Request { success: false }) - .await - .ok(); - stats_tx - .send(StatEvent::ResponseTime(frame_start.elapsed())) + .send(StatEvent::RequestProcessed { + addr: peer_addr, + success: false, + duration_ms: frame_start.elapsed().as_millis() as u64, + }) .await + .map_err(|e| { + warn!("Failed to send stats event: {}", e); + }) .ok(); + return Err(e); } }; @@ -586,31 +518,53 @@ async fn handle_client_inner( Ok(response) => { // Record successful Modbus request stats_tx - .send(StatEvent::Request { success: true }) - .await - .ok(); - stats_tx - .send(StatEvent::ResponseTime(frame_start.elapsed())) + .send(StatEvent::RequestProcessed { + addr: peer_addr, + success: true, + duration_ms: frame_start.elapsed().as_millis() as u64, + }) .await + .map_err(|e| { + warn!("Failed to send stats event: {}", e); + }) .ok(); + response } Err(e) => { // Record failed Modbus request stats_tx - .send(StatEvent::Request { success: false }) - .await - .ok(); - stats_tx - .send(StatEvent::ResponseTime(frame_start.elapsed())) + .send(StatEvent::RequestProcessed { + addr: peer_addr, + success: false, + duration_ms: frame_start.elapsed().as_millis() as u64, + }) .await + .map_err(|e| { + warn!("Failed to send stats event: {}", e); + }) .ok(); + return Err(e); } }; // 3. Send response - send_response(&mut writer, &response, peer_addr).await? + if let Err(e) = send_response(&mut writer, &response, peer_addr).await { + stats_tx + .send(StatEvent::RequestProcessed { + addr: peer_addr, + success: false, + duration_ms: frame_start.elapsed().as_millis() as u64, + }) + .await + .map_err(|e| { + warn!("Failed to send stats event: {}", e); + }) + .ok(); + + return Err(e); + } } Ok(()) diff --git a/src/stats_manager.rs b/src/stats_manager.rs new file mode 100644 index 0000000..e8f6fee --- /dev/null +++ b/src/stats_manager.rs @@ -0,0 +1,248 @@ +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::SystemTime}; + +use tokio::sync::{broadcast, mpsc, Mutex}; +use tracing::{debug, info, warn}; + +use crate::{config::StatsConfig, connection::StatEvent, ClientStats, ConnectionStats}; + +pub struct StatsManager { + stats: Arc>>, + event_rx: mpsc::Receiver, + config: StatsConfig, + total_connections: u64, +} + +impl StatsManager { + pub fn new(config: StatsConfig) -> (Self, mpsc::Sender) { + let (tx, rx) = mpsc::channel(config.max_events_per_second as usize); + + let manager = Self { + stats: Arc::new(Mutex::new(HashMap::new())), + event_rx: rx, + config, + total_connections: 0, + }; + + (manager, tx) + } + + pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<()>) { + let mut cleanup_interval = tokio::time::interval(self.config.cleanup_interval); + + loop { + tokio::select! { + Some(event) = self.event_rx.recv() => { + self.handle_event(event).await; + } + + _ = cleanup_interval.tick() => { + self.cleanup_idle_stats().await; + } + + _ = shutdown_rx.recv() => { + info!("Stats manager shutting down"); + break; + } + } + } + } + + async fn handle_event(&mut self, event: StatEvent) { + let mut stats = self.stats.lock().await; + + match event { + StatEvent::ClientConnected(addr) => { + let client_stats = stats.entry(addr).or_default(); + client_stats.active_connections = client_stats.active_connections.saturating_add(1); + client_stats.last_active = SystemTime::now(); + self.total_connections = self.total_connections.saturating_add(1); + debug!("Client connected from {}", addr); + } + + StatEvent::ClientDisconnected(addr) => { + if let Some(client_stats) = stats.get_mut(&addr) { + client_stats.active_connections = + client_stats.active_connections.saturating_sub(1); + client_stats.last_active = SystemTime::now(); + debug!("Client disconnected from {}", addr); + } + } + + StatEvent::RequestProcessed { + addr, + success, + duration_ms, + } => { + let client_stats = stats.entry(addr).or_default(); + client_stats.total_requests = client_stats.total_requests.saturating_add(1); + + if !success { + client_stats.total_errors = client_stats.total_errors.saturating_add(1); + client_stats.last_error = Some(SystemTime::now()); + } + + // Update average response time using exponential moving average + const ALPHA: f64 = 0.1; // Smoothing factor + + if client_stats.avg_response_time_ms == 0 { + client_stats.avg_response_time_ms = duration_ms; + } else { + let current_avg = client_stats.avg_response_time_ms as f64; + client_stats.avg_response_time_ms = + (current_avg + ALPHA * (duration_ms as f64 - current_avg)) as u64; + } + + client_stats.last_active = SystemTime::now(); + } + + StatEvent::QueryStats { addr, response_tx } => { + if let Some(stats) = stats.get(&addr) { + if response_tx.send(stats.clone()).is_err() { + warn!("Failed to send stats for {}", addr); + } + } + } + + StatEvent::QueryConnectionStats { response_tx } => { + let conn_stats = ConnectionStats::from_client_stats(&stats); + if response_tx.send(conn_stats).is_err() { + warn!("Failed to send connection stats"); + } + } + } + } + + async fn cleanup_idle_stats(&self) { + let mut stats = self.stats.lock().await; + let now = SystemTime::now(); + + stats.retain(|addr, client_stats| { + // Check if client has been idle for too long + let is_idle = now + .duration_since(client_stats.last_active) + .map(|idle_time| idle_time <= self.config.idle_timeout) + .unwrap_or(true); + + // Check if there was an error that's old enough to clean up + let has_recent_error = client_stats + .last_error + .and_then(|last_error| now.duration_since(last_error).ok()) + .map(|error_time| error_time <= self.config.error_timeout) + .unwrap_or(false); + + let should_retain = is_idle || has_recent_error; + + if !should_retain { + debug!( + "Cleaning up stats for {}: {} connections, {} requests, {} errors", + addr, + client_stats.active_connections, + client_stats.total_requests, + client_stats.total_errors + ); + } + + should_retain + }); + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + use tokio::{sync::oneshot, time::sleep}; + + #[tokio::test] + async fn test_client_lifecycle() { + let config = StatsConfig::default(); + let (mut manager, tx) = StatsManager::new(config); + let addr = "127.0.0.1:8080".parse().unwrap(); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + let manager_handle = tokio::spawn(async move { + manager.run(shutdown_rx).await; + }); + + // Test connection + tx.send(StatEvent::ClientConnected(addr)).await.unwrap(); + + // Test successful request + tx.send(StatEvent::RequestProcessed { + addr, + success: true, + duration_ms: Duration::from_millis(100).as_millis() as u64, + }) + .await + .unwrap(); + + // Test failed request + tx.send(StatEvent::RequestProcessed { + addr, + success: false, + duration_ms: Duration::from_millis(150).as_millis() as u64, + }) + .await + .unwrap(); + + sleep(Duration::from_millis(100)).await; + + // Query per-client stats + let (response_tx, response_rx) = oneshot::channel(); + tx.send(StatEvent::QueryStats { addr, response_tx }) + .await + .unwrap(); + + let stats = response_rx.await.unwrap(); + assert_eq!(stats.active_connections, 1); + assert_eq!(stats.total_requests, 2); + assert_eq!(stats.total_errors, 1); + + // Query global stats + let (response_tx, response_rx) = oneshot::channel(); + tx.send(StatEvent::QueryConnectionStats { response_tx }) + .await + .unwrap(); + + let conn_stats = response_rx.await.unwrap(); + assert_eq!(conn_stats.total_requests, 2); + assert_eq!(conn_stats.total_errors, 1); + + // Cleanup + shutdown_tx.send(()).unwrap(); + manager_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_cleanup_idle_stats() { + let mut config = StatsConfig::default(); + config.idle_timeout = Duration::from_millis(100); + let (mut manager, tx) = StatsManager::new(config); + let addr = "127.0.0.1:8080".parse().unwrap(); + + let (shutdown_tx, shutdown_rx) = broadcast::channel(1); + let manager_handle = tokio::spawn(async move { + manager.run(shutdown_rx).await; + }); + + // Add client and disconnect + tx.send(StatEvent::ClientConnected(addr)).await.unwrap(); + tx.send(StatEvent::ClientDisconnected(addr)).await.unwrap(); + + // Wait for idle timeout + sleep(Duration::from_millis(200)).await; + + // Query stats - should be cleaned up + let (response_tx, response_rx) = oneshot::channel(); + tx.send(StatEvent::QueryConnectionStats { response_tx }) + .await + .unwrap(); + + let conn_stats = response_rx.await.unwrap(); + assert_eq!(conn_stats.active_connections, 0); + + shutdown_tx.send(()).unwrap(); + manager_handle.await.unwrap(); + } +}