Add KeepAlive updating to ApiWorkerScheduler (#1310)
* Client action listeners will now timeout actions

Action listeners will now flag actions as timed out themselves, instead of
relying solely on the scheduler detecting that the worker went offline.

This is to support distributed schedulers. Since any worker or scheduler
may now go offline, we need to ensure the "owner" of the action is the
party that actually cares about it: in this case, the client.

towards #359

* Add KeepAlive updating to ApiWorkerScheduler

This ensures that actions are periodically updated with a keep-alive
message, which will be heavily used with a distributed scheduling
model.

towards #359
allada authored Sep 5, 2024
1 parent 96db0cb commit 37ebd58
Showing 11 changed files with 585 additions and 138 deletions.
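The client-side timeout the commit message describes comes down to comparing an action's last worker update against a timeout window. Below is a minimal, hypothetical sketch of that staleness check; `AwaitedActionView` and `is_timed_out` are illustrative stand-ins, not nativelink APIs (the real code tracks this state via `AwaitedAction::last_worker_updated_timestamp` and the `worker_timeout_s` config).

```rust
use std::time::{Duration, SystemTime};

/// Illustrative stand-in: the real state lives on nativelink's
/// AwaitedAction as last_worker_updated_timestamp.
struct AwaitedActionView {
    last_worker_updated_timestamp: SystemTime,
}

/// Returns true when no keep-alive (or other worker update) has arrived
/// within `timeout`, so a listener can flag the action as timed out
/// instead of waiting for the scheduler to notice a dead worker.
fn is_timed_out(action: &AwaitedActionView, now: SystemTime, timeout: Duration) -> bool {
    now.duration_since(action.last_worker_updated_timestamp)
        .map(|elapsed| elapsed > timeout)
        .unwrap_or(false) // Timestamp in the future (clock skew): treat as alive.
}

fn main() {
    let start = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
    let action = AwaitedActionView { last_worker_updated_timestamp: start };
    assert!(!is_timed_out(&action, start + Duration::from_secs(10), Duration::from_secs(30)));
    assert!(is_timed_out(&action, start + Duration::from_secs(60), Duration::from_secs(30)));
}
```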
89 changes: 72 additions & 17 deletions nativelink-scheduler/src/api_worker_scheduler.rs
@@ -23,9 +23,12 @@ use nativelink_metric::{
group, MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
RootMetricsComponent,
};
use nativelink_util::action_messages::{ActionStage, OperationId, WorkerId};
use nativelink_util::operation_state_manager::WorkerStateManager;
use nativelink_util::action_messages::{OperationId, WorkerId};
use nativelink_util::operation_state_manager::{UpdateOperationType, WorkerStateManager};
use nativelink_util::platform_properties::PlatformProperties;
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use tokio::sync::mpsc::{self, UnboundedSender};
use tokio::sync::Notify;
use tonic::async_trait;
use tracing::{event, Level};
@@ -81,6 +84,8 @@ struct ApiWorkerSchedulerImpl {
allocation_strategy: WorkerAllocationStrategy,
/// A channel to notify the matching engine that the worker pool has changed.
worker_change_notify: Arc<Notify>,
/// A channel to notify that an operation is still alive.
operation_keep_alive_tx: UnboundedSender<(OperationId, WorkerId)>,
}

impl ApiWorkerSchedulerImpl {
@@ -103,6 +108,20 @@ impl ApiWorkerSchedulerImpl {
timestamp
);
worker.last_update_timestamp = timestamp;
for operation_id in worker.running_action_infos.keys() {
if self
.operation_keep_alive_tx
.send((operation_id.clone(), *worker_id))
.is_err()
{
event!(
Level::ERROR,
?operation_id,
?worker_id,
"OperationKeepAliveTx stream closed"
);
}
}
Ok(())
}

@@ -177,7 +196,7 @@
&mut self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
let worker = self.workers.get_mut(worker_id).err_tip(|| {
format!("Worker {worker_id} does not exist in SimpleScheduler::update_action")
@@ -193,11 +212,21 @@
.merge(self.immediate_evict_worker(worker_id, err).await);
}

let (is_finished, due_to_backpressure) = match &update {
UpdateOperationType::UpdateWithActionStage(action_stage) => {
(action_stage.is_finished(), false)
}
UpdateOperationType::KeepAlive => (false, false),
UpdateOperationType::UpdateWithError(err) => {
(true, err.code == Code::ResourceExhausted)
}
};

// Update the operation in the worker state manager.
{
let update_operation_res = self
.worker_state_manager
.update_operation(operation_id, worker_id, action_stage.clone())
.update_operation(operation_id, worker_id, update)
.await
.err_tip(|| "in update_operation on SimpleScheduler::update_action");
if let Err(err) = update_operation_res {
@@ -212,10 +241,6 @@
}
}

// We are done if the action is not finished or there was an error.
let is_finished = action_stage
.as_ref()
.map_or_else(|_| true, |action_stage| action_stage.is_finished());
if !is_finished {
return Ok(());
}
@@ -227,9 +252,6 @@
// Note: We need to run this before dealing with backpressure logic.
let complete_action_res = worker.complete_action(operation_id);

let due_to_backpressure = action_stage
.as_ref()
.map_or_else(|e| e.code == Code::ResourceExhausted, |_| false);
// Only pause if there's an action still waiting that will unpause.
if (was_paused || due_to_backpressure) && worker.has_actions() {
worker.is_paused = true;
@@ -296,7 +318,11 @@ impl ApiWorkerSchedulerImpl {
for (operation_id, _) in worker.running_action_infos.drain() {
result = result.merge(
self.worker_state_manager
.update_operation(&operation_id, worker_id, Err(err.clone()))
.update_operation(
&operation_id,
worker_id,
UpdateOperationType::UpdateWithError(err.clone()),
)
.await,
);
}
@@ -319,6 +345,7 @@ pub struct ApiWorkerScheduler {
help = "Timeout of how long to evict workers if no response in this given amount of time in seconds."
)]
worker_timeout_s: u64,
_operation_keep_alive_spawn: JoinHandleDropGuard<()>,
}

impl ApiWorkerScheduler {
@@ -329,15 +356,45 @@
worker_change_notify: Arc<Notify>,
worker_timeout_s: u64,
) -> Arc<Self> {
let (operation_keep_alive_tx, mut operation_keep_alive_rx) = mpsc::unbounded_channel();
Arc::new(Self {
inner: Mutex::new(ApiWorkerSchedulerImpl {
workers: Workers(LruCache::unbounded()),
worker_state_manager,
worker_state_manager: worker_state_manager.clone(),
allocation_strategy,
worker_change_notify,
operation_keep_alive_tx,
}),
platform_property_manager,
worker_timeout_s,
_operation_keep_alive_spawn: spawn!(
"simple_scheduler_operation_keep_alive",
async move {
const RECV_MANY_LIMIT: usize = 256;
let mut messages = Vec::with_capacity(RECV_MANY_LIMIT);
loop {
messages.clear();
operation_keep_alive_rx
.recv_many(&mut messages, RECV_MANY_LIMIT)
.await;
if messages.is_empty() {
return; // Looks like our sender has been dropped.
}
for (operation_id, worker_id) in messages.drain(..) {
let update_operation_res = worker_state_manager
.update_operation(
&operation_id,
&worker_id,
UpdateOperationType::KeepAlive,
)
.await;
if let Err(err) = update_operation_res {
event!(Level::WARN, ?err, "Error while running worker_keep_alive_received, maybe job is done?");
}
}
}
}
),
})
}

@@ -408,12 +465,10 @@ impl WorkerScheduler for ApiWorkerScheduler {
&self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
let mut inner = self.inner.lock().await;
inner
.update_action(worker_id, operation_id, action_stage)
.await
inner.update_action(worker_id, operation_id, update).await
}

async fn worker_keep_alive_received(
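The keep-alive task added above drains an unbounded channel in batches via `recv_many` and exits when the sender side is dropped. Here is a self-contained sketch of the same pattern, with `OperationId`/`WorkerId` reduced to simple type aliases and the `update_operation` call replaced by a print:

```rust
use tokio::sync::mpsc;

// Simple stand-ins for nativelink_util::action_messages::{OperationId, WorkerId}.
type OperationId = String;
type WorkerId = u64;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<(OperationId, WorkerId)>();

    // Background task draining keep-alives in batches of up to 256,
    // mirroring the RECV_MANY_LIMIT loop in ApiWorkerScheduler::new.
    let keep_alive_task = tokio::spawn(async move {
        const RECV_MANY_LIMIT: usize = 256;
        let mut messages = Vec::with_capacity(RECV_MANY_LIMIT);
        loop {
            messages.clear();
            rx.recv_many(&mut messages, RECV_MANY_LIMIT).await;
            if messages.is_empty() {
                return; // Sender dropped; shut the task down.
            }
            for (operation_id, worker_id) in messages.drain(..) {
                // The real task calls worker_state_manager.update_operation(
                //     &operation_id, &worker_id, UpdateOperationType::KeepAlive).
                println!("keep-alive: {operation_id} on worker {worker_id}");
            }
        }
    });

    tx.send(("op-1".to_string(), 42)).unwrap();
    drop(tx); // Closing the channel makes recv_many return with no messages.
    keep_alive_task.await.unwrap();
}
```

Batching with `recv_many` means one worker heartbeat that fans out to many running operations costs one wakeup of the background task rather than one per operation.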
41 changes: 17 additions & 24 deletions nativelink-scheduler/src/awaited_action_db/awaited_action.rs
@@ -15,7 +15,6 @@
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use bytes::Bytes;
use nativelink_error::{make_input_err, Error, ResultExt};
use nativelink_metric::{
MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
@@ -102,61 +101,55 @@ impl AwaitedAction {
}
}

pub fn version(&self) -> u64 {
pub(crate) fn version(&self) -> u64 {
self.version.0
}

pub fn increment_version(&mut self) {
pub(crate) fn increment_version(&mut self) {
self.version = AwaitedActionVersion(self.version.0 + 1);
}

pub fn action_info(&self) -> &Arc<ActionInfo> {
pub(crate) fn action_info(&self) -> &Arc<ActionInfo> {
&self.action_info
}

pub fn operation_id(&self) -> &OperationId {
pub(crate) fn operation_id(&self) -> &OperationId {
&self.operation_id
}

pub fn sort_key(&self) -> AwaitedActionSortKey {
pub(crate) fn sort_key(&self) -> AwaitedActionSortKey {
self.sort_key
}

pub fn state(&self) -> &Arc<ActionState> {
pub(crate) fn state(&self) -> &Arc<ActionState> {
&self.state
}

pub fn worker_id(&self) -> Option<WorkerId> {
pub(crate) fn worker_id(&self) -> Option<WorkerId> {
self.worker_id
}

pub fn last_worker_updated_timestamp(&self) -> SystemTime {
pub(crate) fn last_worker_updated_timestamp(&self) -> SystemTime {
self.last_worker_updated_timestamp
}

pub(crate) fn keep_alive(&mut self, now: SystemTime) {
self.last_worker_updated_timestamp = now;
}

/// Sets the worker id that is currently processing this action.
pub fn set_worker_id(&mut self, new_maybe_worker_id: Option<WorkerId>) {
pub(crate) fn set_worker_id(&mut self, new_maybe_worker_id: Option<WorkerId>, now: SystemTime) {
if self.worker_id != new_maybe_worker_id {
self.worker_id = new_maybe_worker_id;
self.last_worker_updated_timestamp = SystemTime::now();
self.keep_alive(now);
}
}

/// Sets the current state of the action and notifies subscribers.
/// Returns true if the state was set, false if there are no subscribers.
pub fn set_state(&mut self, mut state: Arc<ActionState>) {
pub(crate) fn set_state(&mut self, mut state: Arc<ActionState>, now: SystemTime) {
std::mem::swap(&mut self.state, &mut state);
self.last_worker_updated_timestamp = SystemTime::now();
}
}

impl TryInto<bytes::Bytes> for AwaitedAction {
type Error = Error;
fn try_into(self) -> Result<Bytes, Self::Error> {
serde_json::to_string(&self)
.map(Bytes::from)
.map_err(|e| make_input_err!("{}", e.to_string()))
.err_tip(|| "In AwaitedAction::TryInto::Bytes")
self.keep_alive(now);
}
}

@@ -216,7 +209,7 @@ impl AwaitedActionSortKey {
Self::new(priority, timestamp)
}

pub fn as_u64(&self) -> u64 {
pub(crate) fn as_u64(&self) -> u64 {
self.0
}
}
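The `AwaitedAction` changes above replace internal `SystemTime::now()` calls with a caller-supplied `now`, so `set_worker_id` and `set_state` route through the new `keep_alive` helper and timestamp updates become deterministic under test. A minimal sketch of that injected-clock pattern, using a simplified stand-in for the struct:

```rust
use std::time::{Duration, SystemTime};

// Simplified stand-in for AwaitedAction; only the timestamp logic is shown.
struct AwaitedAction {
    last_worker_updated_timestamp: SystemTime,
}

impl AwaitedAction {
    // The caller supplies `now`, matching the diff's removal of
    // SystemTime::now() calls from inside the struct's methods.
    fn keep_alive(&mut self, now: SystemTime) {
        self.last_worker_updated_timestamp = now;
    }
}

fn main() {
    let t0 = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
    let mut action = AwaitedAction { last_worker_updated_timestamp: t0 };

    // A test can drive the clock forward deterministically.
    action.keep_alive(t0 + Duration::from_secs(30));
    assert_eq!(action.last_worker_updated_timestamp, t0 + Duration::from_secs(30));
}
```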
10 changes: 0 additions & 10 deletions nativelink-scheduler/src/awaited_action_db/mod.rs
@@ -54,16 +54,6 @@ impl TryFrom<ActionStage> for SortedAwaitedActionState {
}
}

impl std::fmt::Display for SortedAwaitedActionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SortedAwaitedActionState::CacheCheck => write!(f, "CacheCheck"),
SortedAwaitedActionState::Queued => write!(f, "Queued"),
SortedAwaitedActionState::Executing => write!(f, "Executing"),
SortedAwaitedActionState::Completed => write!(f, "Completed"),
}
}
}
/// A struct pointing to an AwaitedAction that can be sorted.
#[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)]
pub struct SortedAwaitedAction {
24 changes: 17 additions & 7 deletions nativelink-scheduler/src/simple_scheduler.rs
@@ -13,18 +13,18 @@
// limitations under the License.

use std::sync::Arc;
use std::time::SystemTime;

use async_trait::async_trait;
use futures::Future;
use nativelink_error::{Code, Error, ResultExt};
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
use nativelink_util::action_messages::{
ActionInfo, ActionStage, ActionState, OperationId, WorkerId,
};
use nativelink_util::action_messages::{ActionInfo, ActionState, OperationId, WorkerId};
use nativelink_util::instant_wrapper::InstantWrapper;
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
use nativelink_util::operation_state_manager::{
ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
OperationFilter, OperationStageFlags, OrderDirection,
OperationFilter, OperationStageFlags, OrderDirection, UpdateOperationType,
};
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
@@ -295,18 +295,22 @@ impl SimpleScheduler {
tokio::time::sleep(Duration::from_millis(1))
},
task_change_notify,
SystemTime::now,
)
}

pub fn new_with_callback<
Fut: Future<Output = ()> + Send,
F: Fn() -> Fut + Send + Sync + 'static,
A: AwaitedActionDb,
I: InstantWrapper,
NowFn: Fn() -> I + Clone + Send + Unpin + Sync + 'static,
>(
scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler,
awaited_action_db: A,
on_matching_engine_run: F,
task_change_notify: Arc<Notify>,
now_fn: NowFn,
) -> (Arc<Self>, Arc<dyn WorkerScheduler>) {
let platform_property_manager = Arc::new(PlatformPropertyManager::new(
scheduler_cfg
@@ -326,7 +330,13 @@
}

let worker_change_notify = Arc::new(Notify::new());
let state_manager = SimpleSchedulerStateManager::new(max_job_retries, awaited_action_db);
let state_manager = SimpleSchedulerStateManager::new(
max_job_retries,
// TODO(allada) This should probably have its own config.
Duration::from_secs(worker_timeout_s),
awaited_action_db,
now_fn,
);

let worker_scheduler = ApiWorkerScheduler::new(
state_manager.clone(),
@@ -426,10 +436,10 @@ impl WorkerScheduler for SimpleScheduler {
&self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
self.worker_scheduler
.update_action(worker_id, operation_id, action_stage)
.update_action(worker_id, operation_id, update)
.await
}

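Across these files, the `Result<ActionStage, Error>` parameter of `update_action` becomes `UpdateOperationType`, and the scheduler classifies each variant as finished and/or backpressure. A standalone sketch of that classification, with simplified stand-ins for the nativelink types (only the variant names match the diff):

```rust
// Simplified stand-ins for the nativelink types; just enough for the match.
#[derive(Clone, Copy, PartialEq)]
#[allow(dead_code)]
enum Code {
    ResourceExhausted,
    Internal,
}

struct Error {
    code: Code,
}

struct ActionStage {
    finished: bool,
}

impl ActionStage {
    fn is_finished(&self) -> bool {
        self.finished
    }
}

enum UpdateOperationType {
    KeepAlive,
    UpdateWithActionStage(ActionStage),
    UpdateWithError(Error),
}

// Mirrors the match in ApiWorkerSchedulerImpl::update_action: only a finished
// stage or an error completes the action, and only a ResourceExhausted error
// counts as backpressure (which can re-pause the worker).
fn classify(update: &UpdateOperationType) -> (bool, bool) {
    match update {
        UpdateOperationType::UpdateWithActionStage(stage) => (stage.is_finished(), false),
        UpdateOperationType::KeepAlive => (false, false),
        UpdateOperationType::UpdateWithError(err) => (true, err.code == Code::ResourceExhausted),
    }
}

fn main() {
    let update = UpdateOperationType::UpdateWithError(Error {
        code: Code::ResourceExhausted,
    });
    assert_eq!(classify(&update), (true, true));
    assert_eq!(classify(&UpdateOperationType::KeepAlive), (false, false));
}
```

Folding keep-alives into the same update path is what lets the scheduler refresh an operation's liveness without touching its completion or backpressure logic.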