Skip to content

Commit

Permalink
Change AwaitedAction's API to always return Result<Stream>
Browse files Browse the repository at this point in the history
Cosmetic change to make it easier to support databases that are async in
nature. This makes all operations in the AwaitedAction API to be able to
return a result before returning a stream or other items.

towards #359
  • Loading branch information
allada committed Sep 5, 2024
1 parent 00fa82d commit 128aa7f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 102 deletions.
10 changes: 7 additions & 3 deletions nativelink-scheduler/src/awaited_action_db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static {
fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;

/// Get the current awaited action.
fn borrow(&self) -> AwaitedAction;
fn borrow(&self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
}

/// A trait that defines the interface for an AwaitedActionDb.
Expand All @@ -149,7 +149,9 @@ pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static {
/// Get all AwaitedActions. This call should be avoided as much as possible.
fn get_all_awaited_actions(
&self,
) -> impl Future<Output = impl Stream<Item = Result<Self::Subscriber, Error>> + Send> + Send;
) -> impl Future<
Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
> + Send;

/// Get the AwaitedAction by the operation id.
fn get_by_operation_id(
Expand All @@ -164,7 +166,9 @@ pub trait AwaitedActionDb: Send + Sync + MetricsComponent + Unpin + 'static {
start: Bound<SortedAwaitedAction>,
end: Bound<SortedAwaitedAction>,
desc: bool,
) -> impl Future<Output = impl Stream<Item = Result<Self::Subscriber, Error>> + Send> + Send;
) -> impl Future<
Output = Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error>,
> + Send;

/// Process a change changed AwaitedAction and notify any listeners.
fn update_awaited_action(
Expand Down
103 changes: 57 additions & 46 deletions nativelink-scheduler/src/memory_awaited_action_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ where
Ok(awaited_action)
}

fn borrow(&self) -> AwaitedAction {
async fn borrow(&self) -> Result<AwaitedAction, Error> {
let mut awaited_action = self.awaited_action_rx.borrow().clone();
if let Some(client_info) = self.client_info.as_ref() {
let mut state = awaited_action.state().as_ref().clone();
state.client_operation_id = client_info.client_operation_id.clone();
awaited_action.set_state(Arc::new(state), None);
}
awaited_action
Ok(awaited_action)
}
}

Expand Down Expand Up @@ -504,22 +504,25 @@ impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbI
&'a self,
state: SortedAwaitedActionState,
range: impl RangeBounds<SortedAwaitedAction> + 'b,
) -> impl DoubleEndedIterator<
Item = Result<
(
&'a SortedAwaitedAction,
MemoryAwaitedActionSubscriber<I, NowFn>,
),
Error,
>,
> + 'a {
) -> Result<
impl DoubleEndedIterator<
Item = Result<
(
&'a SortedAwaitedAction,
MemoryAwaitedActionSubscriber<I, NowFn>,
),
Error,
>,
> + 'a,
Error,
> {
let btree = match state {
SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check,
SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued,
SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing,
SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed,
};
btree.range(range).map(|sorted_awaited_action| {
Ok(btree.range(range).map(|sorted_awaited_action| {
let operation_id = &sorted_awaited_action.operation_id;
self.get_by_operation_id(operation_id)
.ok_or_else(|| {
Expand All @@ -530,7 +533,7 @@ impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync> AwaitedActionDbI
)
})
.map(|subscriber| (sorted_awaited_action, subscriber))
})
}))
}

fn process_state_changes_for_hash_key_map(
Expand Down Expand Up @@ -878,8 +881,10 @@ impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> Awaite
.await
}

async fn get_all_awaited_actions(&self) -> impl Stream<Item = Result<Self::Subscriber, Error>> {
ChunkedStream::new(
async fn get_all_awaited_actions(
&self,
) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>>, Error> {
Ok(ChunkedStream::new(
Bound::Unbounded,
Bound::Unbounded,
move |start, end, mut output| async move {
Expand All @@ -896,7 +901,7 @@ impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> Awaite
Ok(maybe_new_start
.map(|new_start| ((Bound::Excluded(new_start.clone()), end), output)))
},
)
))
}

async fn get_by_operation_id(
Expand All @@ -912,38 +917,44 @@ impl<I: InstantWrapper, NowFn: Fn() -> I + Clone + Send + Sync + 'static> Awaite
start: Bound<SortedAwaitedAction>,
end: Bound<SortedAwaitedAction>,
desc: bool,
) -> impl Stream<Item = Result<Self::Subscriber, Error>> + Send {
ChunkedStream::new(start, end, move |start, end, mut output| async move {
let inner = self.inner.lock().await;
let mut done = true;
let mut new_start = start.as_ref();
let mut new_end = end.as_ref();

let iterator = inner.get_range_of_actions(state, (start.as_ref(), end.as_ref()));
// TODO(allada) This should probably use the `.left()/right()` pattern,
// but that doesn't exist in the std or any libraries we use.
if desc {
for result in iterator.rev() {
let (sorted_awaited_action, item) =
result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
output.push_back(item);
new_end = Bound::Excluded(sorted_awaited_action);
done = false;
) -> Result<impl Stream<Item = Result<Self::Subscriber, Error>> + Send, Error> {
Ok(ChunkedStream::new(
start,
end,
move |start, end, mut output| async move {
let inner = self.inner.lock().await;
let mut done = true;
let mut new_start = start.as_ref();
let mut new_end = end.as_ref();

let iterator = inner
.get_range_of_actions(state, (start.as_ref(), end.as_ref()))
.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
// TODO(allada) This should probably use the `.left()/right()` pattern,
// but that doesn't exist in the std or any libraries we use.
if desc {
for result in iterator.rev() {
let (sorted_awaited_action, item) =
result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
output.push_back(item);
new_end = Bound::Excluded(sorted_awaited_action);
done = false;
}
} else {
for result in iterator {
let (sorted_awaited_action, item) =
result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
output.push_back(item);
new_start = Bound::Excluded(sorted_awaited_action);
done = false;
}
}
} else {
for result in iterator {
let (sorted_awaited_action, item) =
result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?;
output.push_back(item);
new_start = Bound::Excluded(sorted_awaited_action);
done = false;
if done {
return Ok(None);
}
}
if done {
return Ok(None);
}
Ok(Some(((new_start.cloned(), new_end.cloned()), output)))
})
Ok(Some(((new_start.cloned(), new_end.cloned()), output)))
},
))
}

async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> {
Expand Down
Loading

0 comments on commit 128aa7f

Please sign in to comment.