diff --git a/src/lib.rs b/src/lib.rs index 6fc8d10..375de7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,7 @@ impl Manager { bind: String, store_addr: String, world_size: u64, + heartbeat_interval: Duration, ) -> PyResult { py.allow_threads(move || { let runtime = Runtime::new()?; @@ -55,6 +56,7 @@ impl Manager { bind, store_addr, world_size, + heartbeat_interval, )) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let handle = runtime.spawn(manager.clone().run()); @@ -224,6 +226,7 @@ struct Lighthouse { #[pymethods] impl Lighthouse { + #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))] #[new] fn new( py: Python<'_>, @@ -231,9 +234,11 @@ impl Lighthouse { min_replicas: u64, join_timeout_ms: Option, quorum_tick_ms: Option, + heartbeat_timeout_ms: Option, ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); + let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000); py.allow_threads(move || { let rt = Runtime::new()?; @@ -244,6 +249,7 @@ impl Lighthouse { min_replicas: min_replicas, join_timeout_ms: join_timeout_ms, quorum_tick_ms: quorum_tick_ms, + heartbeat_timeout_ms: heartbeat_timeout_ms, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; diff --git a/src/lighthouse.rs b/src/lighthouse.rs index c22fe23..c195b99 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -38,6 +38,7 @@ use crate::torchftpb::{ LighthouseQuorumResponse, Quorum, QuorumMember, }; +#[derive(Clone)] struct QuorumMemberDetails { joined: Instant, member: QuorumMember, @@ -69,17 +70,39 @@ pub struct Lighthouse { #[structopt()] pub struct LighthouseOpt { // bind is the address to bind the server to. - #[structopt(long = "bind", default_value = "[::]:29510")] + #[structopt( + long = "bind", + default_value = "[::]:29510", + help = "Address to bind the server to" + )] pub bind: String, - #[structopt(long = "join_timeout_ms", default_value = "60000")] + #[structopt( + long = "join_timeout_ms", + default_value = "60000", + help = "How long to wait for new replicas to join before considering a quorum" + )] pub join_timeout_ms: u64, - #[structopt(long = "min_replicas")] + #[structopt( + long = "min_replicas", + help = "Minimum number of replicas to consider a quorum" + )] pub min_replicas: u64, - #[structopt(long = "quorum_tick_ms", default_value = "100")] + #[structopt( + long = "quorum_tick_ms", + default_value = "100", + help = "How frequently to check for quorum when waiting for workers." + )] pub quorum_tick_ms: u64, + + #[structopt( + long = "heartbeat_timeout_ms", + default_value = "5000", + help = "how long to wait for a heartbeat before considering a replica dead." + )] + pub heartbeat_timeout_ms: u64, } fn quorum_changed(a: &Vec, b: &Vec) -> bool { @@ -90,55 +113,83 @@ fn quorum_changed(a: &Vec, b: &Vec) -> bool { } // Checks whether the quorum is valid and an explanation for the state. -fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) { - let mut first_joined = Instant::now(); +fn quorum_valid( + now: Instant, + heartbeats: &HashMap, + state: &RoomState, + opt: &LighthouseOpt, +) -> (bool, String) { + let mut first_joined = now; + + let healthy_participants: HashMap = state + .participants + .clone() + .into_iter() + .filter(|(replica_id, _details)| { + let last_heartbeat = heartbeats.get(replica_id); + if last_heartbeat.is_none() { + return false; + } + + now.duration_since(*last_heartbeat.unwrap()) + < Duration::from_millis(opt.heartbeat_timeout_ms) + }) + .collect(); - for details in state.participants.values() { + for details in healthy_participants.values() { if details.joined < first_joined { first_joined = details.joined; } } + let metadata = format!( + "[{}/{} participants healthy]", + healthy_participants.len(), + state.participants.len() + ); + if state.prev_quorum.is_some() { let mut is_fast_quorum = true; let prev_quorum = state.prev_quorum.as_ref().unwrap(); for prev_member in prev_quorum.participants.iter() { - if !state.participants.contains_key(&prev_member.replica_id) { + if !healthy_participants.contains_key(&prev_member.replica_id) { is_fast_quorum = false; break; } } if is_fast_quorum { - return (is_fast_quorum, format!("Fast quorum found!")); + return (is_fast_quorum, format!("Fast quorum found! {}", metadata)); } } - if state.participants.len() < opt.min_replicas as usize { + if healthy_participants.len() < opt.min_replicas as usize { return ( false, format!( - "No quorum, only have {} participants, need {}", - state.participants.len(), - opt.min_replicas + "No quorum, only have {} participants, need {} {}", + healthy_participants.len(), + opt.min_replicas, + metadata ), ); } // Quorum is valid at this point but lets wait for stragglers. - if Instant::now().duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) { + if now.duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) { return ( false, format!( - "Valid quorum with {} participants, waiting for stragglers due to join timeout", - state.participants.len() + "Valid quorum with {} participants, waiting for stragglers due to join timeout {}", + healthy_participants.len(), + metadata ), ); } - (true, format!("Valid quorum found")) + (true, format!("Valid quorum found {}", metadata)) } impl Lighthouse { @@ -155,8 +206,12 @@ impl Lighthouse { })) } - fn _quorum_tick(self: Arc, state: &mut RoomState) -> Result<()> { - let (quorum_met, reason) = quorum_valid(state, &self.opt); + fn _quorum_tick( + self: Arc, + heartbeats: &HashMap, + state: &mut RoomState, + ) -> Result<()> { + let (quorum_met, reason) = quorum_valid(Instant::now(), heartbeats, state, &self.opt); info!("{}: {}", state.room_id, reason); if quorum_met { @@ -206,8 +261,9 @@ impl Lighthouse { loop { { let mut state = self.state.lock().await; + let heartbeats = state.heartbeats.clone(); for (_room_id, room) in &mut state.rooms { - self.clone()._quorum_tick(room)?; + self.clone()._quorum_tick(&heartbeats, room)?; } } @@ -284,7 +340,8 @@ impl Lighthouse { .rooms .iter() .map(|(room_id, room)| { - let (_, quorum_status) = quorum_valid(&room, &self.opt); + let (_, quorum_status) = + quorum_valid(Instant::now(), &state.heartbeats, &room, &self.opt); let max_step = { if let Some(quorum) = room.prev_quorum.clone() { @@ -314,7 +371,7 @@ impl Lighthouse { rooms: rooms, heartbeats: state.heartbeats.clone(), old_age_threshold: Instant::now() - .checked_sub(Duration::from_secs(1)) + .checked_sub(Duration::from_millis(self.opt.heartbeat_timeout_ms)) .unwrap_or(Instant::now()), } }; @@ -367,6 +424,13 @@ impl LighthouseService for Arc { let mut rx = { let mut state = self.state.lock().await; + // implicit heartbeat + state + .heartbeats + .insert(requester.replica_id.clone(), Instant::now()); + + let heartbeats = state.heartbeats.clone(); + if !state.rooms.contains_key(&room_id) { let (tx, _) = broadcast::channel(16); @@ -395,7 +459,7 @@ impl LighthouseService for Arc { // proactively run quorum tick self.clone() - ._quorum_tick(room) + ._quorum_tick(&heartbeats, room) .map_err(|e| Status::from_error(e.into()))?; rx @@ -497,6 +561,7 @@ mod tests { bind: "[::]:0".to_string(), join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, }; let mut state = RoomState { @@ -506,13 +571,16 @@ mod tests { prev_quorum: None, quorum_id: 0, }; + let mut heartbeats = HashMap::new(); - assert!(!quorum_valid(&state, &opt).0); + let now = Instant::now(); + + assert!(!quorum_valid(now, &heartbeats, &state, &opt).0); state.participants.insert( "a".to_string(), QuorumMemberDetails { - joined: Instant::now(), + joined: now, member: QuorumMember { replica_id: "a".to_string(), address: "".to_string(), @@ -522,13 +590,61 @@ mod tests { }, }, ); + heartbeats.insert("a".to_string(), now); - assert!(!quorum_valid(&state, &opt).0); + assert!(!quorum_valid(now, &heartbeats, &state, &opt).0); state.participants.get_mut("a").unwrap().joined = - Instant::now().sub(Duration::from_secs(10 * 60 * 60)); + now.sub(Duration::from_secs(10 * 60 * 60)); - assert!(quorum_valid(&state, &opt).0); + assert!(quorum_valid(now, &heartbeats, &state, &opt).0); + + Ok(()) + } + + #[tokio::test] + async fn test_quorum_heartbeats() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + }; + + let mut state = RoomState { + room_id: "test".to_string(), + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + }; + let mut heartbeats = HashMap::new(); + + let now = Instant::now(); + + state.participants.insert( + "a".to_string(), + QuorumMemberDetails { + joined: now, + member: QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + }, + }, + ); + heartbeats.insert("a".to_string(), now); + + assert!(quorum_valid(now, &heartbeats, &state, &opt).0); + + // expired heartbeat + heartbeats.insert("a".to_string(), now.sub(Duration::from_secs(10))); + + let (quorum_met, reason) = quorum_valid(now, &heartbeats, &state, &opt); + assert!(!quorum_met, "{}", reason); Ok(()) } @@ -540,6 +656,7 @@ mod tests { bind: "[::]:0".to_string(), join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, }; let mut state = RoomState { @@ -549,13 +666,16 @@ mod tests { prev_quorum: None, quorum_id: 0, }; + let mut heartbeats = HashMap::new(); + + let now = Instant::now(); - assert!(!quorum_valid(&state, &opt).0); + assert!(!quorum_valid(now, &heartbeats, &state, &opt).0); state.participants.insert( "a".to_string(), QuorumMemberDetails { - joined: Instant::now(), + joined: now, member: QuorumMember { replica_id: "a".to_string(), address: "".to_string(), @@ -565,8 +685,9 @@ mod tests { }, }, ); + heartbeats.insert("a".to_string(), now); - assert!(!quorum_valid(&state, &opt).0); + assert!(!quorum_valid(now, &heartbeats, &state, &opt).0); state.prev_quorum = Some(Quorum { quorum_id: 1, @@ -580,7 +701,7 @@ mod tests { created: Some(SystemTime::now().into()), }); - assert!(quorum_valid(&state, &opt).0); + assert!(quorum_valid(now, &heartbeats, &state, &opt).0); Ok(()) } @@ -592,6 +713,7 @@ mod tests { bind: "[::]:0".to_string(), join_timeout_ms: 1, quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, }; let lighthouse = Lighthouse::new(opt).await?; diff --git a/src/manager.rs b/src/manager.rs index 6200b27..f03e10e 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -59,6 +59,7 @@ pub struct Manager { state: Mutex, listener: Mutex>, local_addr: SocketAddr, + heartbeat_interval: Duration, } pub async fn manager_client_new( @@ -84,6 +85,7 @@ impl Manager { bind: String, store_addr: String, world_size: u64, + heartbeat_interval: Duration, ) -> Result> { let listener = tokio::net::TcpListener::bind(&bind).await?; @@ -95,6 +97,7 @@ impl Manager { address: address, store_address: store_addr, world_size: world_size, + heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_servers: HashMap::new(), rooms: HashMap::new(), @@ -156,7 +159,7 @@ impl Manager { let _response = client.heartbeat(request).await; - sleep(Duration::from_millis(100)).await; + sleep(self.heartbeat_interval).await; } } @@ -421,7 +424,8 @@ mod tests { "addr".to_string(), "[::]:29531".to_string(), "store_addr".to_string(), - 2, + 2, // world size + Duration::from_millis(100), // heartbeat interval ) .await?; let manager_fut = tokio::spawn(manager._run_grpc()); @@ -454,6 +458,7 @@ mod tests { join_timeout_ms: 100, min_replicas: 1, quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -464,7 +469,8 @@ mod tests { "addr".to_string(), "[::]:0".to_string(), "store_addr".to_string(), - 1, // world size + 1, // world size + Duration::from_millis(100), // heartbeat interval ) .await?; let manager_fut = tokio::spawn(manager.clone().run()); @@ -502,6 +508,7 @@ mod tests { join_timeout_ms: 100, min_replicas: 2, quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -518,7 +525,8 @@ mod tests { "addr".to_string(), "[::]:0".to_string(), "store_addr".to_string(), - 1, // world size + 1, // world size + Duration::from_millis(100), // heartbeat interval ) .await?; let manager_fut = tokio::spawn(manager.clone().run()); diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index f6efc32..36ab62c 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -91,3 +91,11 @@ def test_join_timeout_behavior(self) -> None: lighthouse.shutdown() if "manager" in locals(): manager.shutdown() + + def test_heartbeat_timeout_ms_sanity(self) -> None: + lighthouse = Lighthouse( + bind="[::]:0", + min_replicas=1, + heartbeat_timeout_ms=100, + ) + lighthouse.shutdown() diff --git a/torchft/manager.py b/torchft/manager.py index f631daf..ebab626 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -94,6 +94,7 @@ def __init__( store_port: Optional[int] = None, lighthouse_addr: Optional[str] = None, replica_id: Optional[str] = None, + heartbeat_interval: timedelta = timedelta(milliseconds=100), ) -> None: """ Args: @@ -161,6 +162,7 @@ def _manager_state_dict() -> Dict[str, T]: bind=bind, store_addr=f"{store_addr}:{store_port}", world_size=world_size, + heartbeat_interval=heartbeat_interval, ) self._store.set(MANAGER_ADDR_KEY, addr) @@ -413,6 +415,7 @@ def _async_quorum(self, room_id: str, allow_heal: bool) -> None: self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") # We use the replica rank and world as we want all replicas in the PG. + # TODO: handle configure errors self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size) self._quorum_id = quorum_id diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index 644a4ea..6cc1a9f 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -18,11 +18,19 @@ class Manager: bind: str, store_addr: str, world_size: int, + heartbeat_interval: timedelta, ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... class Lighthouse: - def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None) -> None: ... + def __init__( + self, + bind: str, + min_replicas: int, + join_timeout_ms: Optional[int] = None, + quorum_tick_ms: Optional[int] = None, + heartbeat_timeout_ms: Optional[int] = None, + ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ...