diff --git a/nativelink-scheduler/src/awaited_action_db/mod.rs b/nativelink-scheduler/src/awaited_action_db/mod.rs index c33ae61a4..7eb1287e7 100644 --- a/nativelink-scheduler/src/awaited_action_db/mod.rs +++ b/nativelink-scheduler/src/awaited_action_db/mod.rs @@ -133,7 +133,7 @@ pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static { fn changed(&mut self) -> impl Future> + Send; /// Get the current awaited action. - fn borrow(&self) -> AwaitedAction; + fn borrow(&self) -> impl Future> + Send; } /// A trait that defines the interface for an AwaitedActionDb. @@ -149,7 +149,9 @@ pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static { /// Get all AwaitedActions. This call should be avoided as much as possible. fn get_all_awaited_actions( &self, - ) -> impl Future> + Send> + Send; + ) -> impl Future< + Output = Result> + Send, Error>, + > + Send; /// Get the AwaitedAction by the operation id. fn get_by_operation_id( @@ -164,7 +166,9 @@ pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static { start: Bound, end: Bound, desc: bool, - ) -> impl Future> + Send> + Send; + ) -> impl Future< + Output = Result> + Send, Error>, + > + Send; /// Process a change changed AwaitedAction and notify any listeners. fn update_awaited_action( diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs index 2815af11f..b06015b47 100644 --- a/nativelink-scheduler/src/memory_awaited_action_db.rs +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -207,14 +207,14 @@ where Ok(awaited_action) } - fn borrow(&self) -> AwaitedAction { + async fn borrow(&self) -> Result { let mut awaited_action = self.awaited_action_rx.borrow().clone(); if let Some(client_info) = self.client_info.as_ref() { let mut state = awaited_action.state().as_ref().clone(); state.client_operation_id = client_info.client_operation_id.clone(); awaited_action.set_state(Arc::new(state), None); } - awaited_action + Ok(awaited_action) } } @@ -504,22 +504,25 @@ impl I + Clone + Send + Sync> AwaitedActionDbI &'a self, state: SortedAwaitedActionState, range: impl RangeBounds + 'b, - ) -> impl DoubleEndedIterator< - Item = Result< - ( - &'a SortedAwaitedAction, - MemoryAwaitedActionSubscriber, - ), - Error, - >, - > + 'a { + ) -> Result< + impl DoubleEndedIterator< + Item = Result< + ( + &'a SortedAwaitedAction, + MemoryAwaitedActionSubscriber, + ), + Error, + >, + > + 'a, + Error, + > { let btree = match state { SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check, SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued, SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing, SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed, }; - btree.range(range).map(|sorted_awaited_action| { + Ok(btree.range(range).map(|sorted_awaited_action| { let operation_id = &sorted_awaited_action.operation_id; self.get_by_operation_id(operation_id) .ok_or_else(|| { @@ -530,7 +533,7 @@ impl I + Clone + Send + Sync> AwaitedActionDbI ) }) .map(|subscriber| (sorted_awaited_action, subscriber)) - }) + })) } fn process_state_changes_for_hash_key_map( @@ -878,8 +881,10 @@ impl I + Clone + Send + Sync + 'static> Awaite .await } - async fn get_all_awaited_actions(&self) -> impl Stream> { - ChunkedStream::new( + async fn get_all_awaited_actions( + &self, + ) -> Result>, Error> { + Ok(ChunkedStream::new( Bound::Unbounded, Bound::Unbounded, move |start, end, mut output| async move { @@ -896,7 +901,7 @@ impl I + Clone + Send + Sync + 'static> Awaite Ok(maybe_new_start .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output))) }, - ) + )) } async fn get_by_operation_id( @@ -912,38 +917,44 @@ impl I + Clone + Send + Sync + 'static> Awaite start: Bound, end: Bound, desc: bool, - ) -> impl Stream> + Send { - ChunkedStream::new(start, end, move |start, end, mut output| async move { - let inner = self.inner.lock().await; - let mut done = true; - let mut new_start = start.as_ref(); - let mut new_end = end.as_ref(); - - let iterator = inner.get_range_of_actions(state, (start.as_ref(), end.as_ref())); - // TODO(allada) This should probably use the `.left()/right()` pattern, - // but that doesn't exist in the std or any libraries we use. - if desc { - for result in iterator.rev() { - let (sorted_awaited_action, item) = - result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; - output.push_back(item); - new_end = Bound::Excluded(sorted_awaited_action); - done = false; + ) -> Result> + Send, Error> { + Ok(ChunkedStream::new( + start, + end, + move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut done = true; + let mut new_start = start.as_ref(); + let mut new_end = end.as_ref(); + + let iterator = inner + .get_range_of_actions(state, (start.as_ref(), end.as_ref())) + .err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + // TODO(allada) This should probably use the `.left()/right()` pattern, + // but that doesn't exist in the std or any libraries we use. + if desc { + for result in iterator.rev() { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_end = Bound::Excluded(sorted_awaited_action); + done = false; + } + } else { + for result in iterator { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_start = Bound::Excluded(sorted_awaited_action); + done = false; + } } - } else { - for result in iterator { - let (sorted_awaited_action, item) = - result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; - output.push_back(item); - new_start = Bound::Excluded(sorted_awaited_action); - done = false; + if done { + return Ok(None); } - } - if done { - return Ok(None); - } - Ok(Some(((new_start.cloned(), new_end.cloned()), output))) - }) + Ok(Some(((new_start.cloned(), new_end.cloned()), output))) + }, + )) } async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs index 98e526448..e98862382 100644 --- a/nativelink-scheduler/src/simple_scheduler_state_manager.rs +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -18,7 +18,7 @@ use std::time::{Duration, SystemTime}; use async_lock::Mutex; use async_trait::async_trait; -use futures::{future, stream, StreamExt, TryStreamExt}; +use futures::{future, stream, FutureExt, StreamExt, TryStreamExt}; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; use nativelink_util::action_messages::{ @@ -222,7 +222,13 @@ where NowFn: Fn() -> I + Clone + Send + Unpin + Sync + 'static, { async fn as_state(&self) -> Result, Error> { - Ok(self.awaited_action_sub.borrow().state().clone()) + Ok(self + .awaited_action_sub + .borrow() + .await + .err_tip(|| "In MatchingEngineActionStateResult::as_state")? + .state() + .clone()) } async fn changed(&mut self) -> Result, Error> { @@ -239,7 +245,11 @@ where } } - let awaited_action = self.awaited_action_sub.borrow(); + let awaited_action = self + .awaited_action_sub + .borrow() + .await + .err_tip(|| "In MatchingEngineActionStateResult::changed")?; if matches!(awaited_action.state().stage, ActionStage::Queued) { // Actions in queued state do not get periodically updated, @@ -278,7 +288,13 @@ where } async fn as_action_info(&self) -> Result, Error> { - Ok(self.awaited_action_sub.borrow().action_info().clone()) + Ok(self + .awaited_action_sub + .borrow() + .await + .err_tip(|| "In MatchingEngineActionStateResult::as_action_info")? + .action_info() + .clone()) } } @@ -362,7 +378,10 @@ where format!("Operation id {operation_id} does not exist in SimpleSchedulerStateManager::timeout_operation_id") })?; - let awaited_action = awaited_action_subscriber.borrow(); + let awaited_action = awaited_action_subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::timeout_operation_id")?; // If the action is not executing, we should not timeout the action. if !matches!(awaited_action.state().stage, ActionStage::Executing) { @@ -420,7 +439,10 @@ where None => return Ok(()), }; - let mut awaited_action = awaited_action_subscriber.borrow(); + let mut awaited_action = awaited_action_subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::update_operation")?; // Make sure the worker id matches the awaited action worker id. // This might happen if the worker sending the update is not the @@ -574,67 +596,81 @@ where } if let Some(operation_id) = &filter.operation_id { - return Ok(self + let maybe_subscriber = self .action_db .get_by_operation_id(operation_id) .await - .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? - .filter(|awaited_action_rx| { - let awaited_action = awaited_action_rx.borrow(); - apply_filter_predicate(&awaited_action, &filter) - }) - .map(|awaited_action| -> ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(awaited_action) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + let Some(subscriber) = maybe_subscriber else { + return Ok(Box::pin(stream::empty())); + }; + let awaited_action = subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + if !apply_filter_predicate(&awaited_action, &filter) { + return Ok(Box::pin(stream::empty())); + } + return Ok(Box::pin(stream::once(async move { + to_action_state_result(subscriber) + }))); } if let Some(client_operation_id) = &filter.client_operation_id { - return Ok(self + let maybe_subscriber = self .action_db .get_awaited_action_by_id(client_operation_id) .await - .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? - .filter(|awaited_action_rx| { - let awaited_action = awaited_action_rx.borrow(); - apply_filter_predicate(&awaited_action, &filter) - }) - .map(|awaited_action| -> ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(awaited_action) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + let Some(subscriber) = maybe_subscriber else { + return Ok(Box::pin(stream::empty())); + }; + let awaited_action = subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + if !apply_filter_predicate(&awaited_action, &filter) { + return Ok(Box::pin(stream::empty())); + } + return Ok(Box::pin(stream::once(async move { + to_action_state_result(subscriber) + }))); } let Some(sorted_awaited_action_state) = sorted_awaited_action_state_for_flags(filter.stages) else { - let mut all_items: Vec = self + let mut all_items: Vec<_> = self .action_db .get_all_awaited_actions() .await - .try_filter(|awaited_action_subscriber| { - future::ready(apply_filter_predicate( - &awaited_action_subscriber.borrow(), - &filter, - )) + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")? + .and_then(|awaited_action_subscriber| async move { + let awaited_action = awaited_action_subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + Ok((awaited_action_subscriber, awaited_action)) + }) + .try_filter_map(|(subscriber, awaited_action)| { + if apply_filter_predicate(&awaited_action, &filter) { + future::ready(Ok(Some((subscriber, awaited_action.sort_key())))) + .left_future() + } else { + future::ready(Result::<_, Error>::Ok(None)).right_future() + } }) .try_collect() .await - .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?; + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; match filter.order_by_priority_direction { - Some(OrderDirection::Asc) => { - all_items.sort_unstable_by_key(|a| a.borrow().sort_key()) - } - Some(OrderDirection::Desc) => { - all_items.sort_unstable_by_key(|a| std::cmp::Reverse(a.borrow().sort_key())) - } + Some(OrderDirection::Asc) => all_items.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)), + Some(OrderDirection::Desc) => all_items.sort_unstable_by(|(_, a), (_, b)| b.cmp(a)), None => {} } return Ok(Box::pin(stream::iter( - all_items.into_iter().map(to_action_state_result), + all_items + .into_iter() + .map(move |(subscriber, _)| to_action_state_result(subscriber)), ))); }; @@ -642,7 +678,6 @@ where filter.order_by_priority_direction, Some(OrderDirection::Desc) ); - let filter = filter.clone(); let stream = self .action_db .get_range_of_actions( @@ -652,7 +687,21 @@ where desc, ) .await - .try_filter(move |sub| future::ready(apply_filter_predicate(&sub.borrow(), &filter))) + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")? + .and_then(|awaited_action_subscriber| async move { + let awaited_action = awaited_action_subscriber + .borrow() + .await + .err_tip(|| "In SimpleSchedulerStateManager::filter_operations")?; + Ok((awaited_action_subscriber, awaited_action)) + }) + .try_filter_map(move |(subscriber, awaited_action)| { + if apply_filter_predicate(&awaited_action, &filter) { + future::ready(Ok(Some(subscriber))).left_future() + } else { + future::ready(Result::<_, Error>::Ok(None)).right_future() + } + }) .map(move |result| -> Box { result.map_or_else( |e| -> Box { Box::new(ErrorActionStateResult(e)) }, diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index cdd9a9c90..d74770782 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -869,11 +869,11 @@ impl AwaitedActionSubscriber for MockAwaitedActionSubscriber { unreachable!(); } - fn borrow(&self) -> AwaitedAction { - AwaitedAction::new( + async fn borrow(&self) -> Result { + Ok(AwaitedAction::new( OperationId::default(), make_base_action_info(SystemTime::UNIX_EPOCH, DigestInfo::zero_digest()), - ) + )) } } @@ -933,8 +933,8 @@ impl AwaitedActionDb for MockAwaitedAction { async fn get_all_awaited_actions( &self, - ) -> impl Stream> + Send { - futures::stream::empty() + ) -> Result> + Send, Error> { + Ok(futures::stream::empty()) } async fn get_by_operation_id( @@ -953,12 +953,12 @@ impl AwaitedActionDb for MockAwaitedAction { _start: Bound, _end: Bound, _desc: bool, - ) -> impl Stream> + Send { + ) -> Result> + Send, Error> { let mut rx_get_range_of_actions = self.rx_get_range_of_actions.lock().await; let items = rx_get_range_of_actions .try_recv() .expect("Could not receive msg in mpsc"); - futures::stream::iter(items) + Ok(futures::stream::iter(items)) } async fn update_awaited_action(&self, _new_awaited_action: AwaitedAction) -> Result<(), Error> {