Add KeepAlive updating to ApiWorkerScheduler

This ensures that actions are periodically updated with a keep-alive
message, which will be heavily used with a distributed scheduling
model.

towards #359
allada committed Sep 2, 2024
1 parent 1765e27 commit 3c9448a
Showing 11 changed files with 205 additions and 99 deletions.
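For readers skimming the diff below, the core of the change is a keep-alive forwarding loop: each time a worker sends a keep-alive, the scheduler pushes one (operation_id, worker_id) pair per running action onto an unbounded channel, and a background task drains that channel in batches and forwards UpdateOperationType::KeepAlive to the worker state manager. A minimal, self-contained sketch of that pattern (simplified names and string IDs, assuming tokio with the macros, rt-multi-thread, and sync features; these are not the actual NativeLink types):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Unbounded channel: keep-alive senders never block on the forwarder.
    let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();

    // Background forwarder, mirroring the task spawned in ApiWorkerScheduler::new.
    let forwarder = tokio::spawn(async move {
        const RECV_MANY_LIMIT: usize = 256;
        let mut messages = Vec::with_capacity(RECV_MANY_LIMIT);
        loop {
            messages.clear();
            rx.recv_many(&mut messages, RECV_MANY_LIMIT).await;
            if messages.is_empty() {
                return; // All senders dropped; shut down.
            }
            for (operation_id, worker_id) in messages.drain(..) {
                // Stand-in for worker_state_manager.update_operation(.., KeepAlive).
                println!("keep-alive: operation {operation_id} on worker {worker_id}");
            }
        }
    });

    // Stand-in for worker_keep_alive_received: one send per running operation.
    tx.send(("op-1".to_string(), "worker-a".to_string())).unwrap();
    drop(tx); // Dropping the last sender lets the forwarder exit cleanly.
    forwarder.await.unwrap();
}
```

In the real code the task handle is held in a JoinHandleDropGuard field (_operation_keep_alive_spawn), which presumably aborts the forwarder when the scheduler itself is dropped.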
89 changes: 72 additions & 17 deletions nativelink-scheduler/src/api_worker_scheduler.rs
@@ -23,9 +23,12 @@ use nativelink_metric::{
group, MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
RootMetricsComponent,
};
use nativelink_util::action_messages::{ActionStage, OperationId, WorkerId};
use nativelink_util::operation_state_manager::WorkerStateManager;
use nativelink_util::action_messages::{OperationId, WorkerId};
use nativelink_util::operation_state_manager::{UpdateOperationType, WorkerStateManager};
use nativelink_util::platform_properties::PlatformProperties;
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use tokio::sync::mpsc::{self, UnboundedSender};
use tokio::sync::Notify;
use tonic::async_trait;
use tracing::{event, Level};
@@ -81,6 +84,8 @@ struct ApiWorkerSchedulerImpl {
allocation_strategy: WorkerAllocationStrategy,
/// A channel to notify the matching engine that the worker pool has changed.
worker_change_notify: Arc<Notify>,
/// A channel to notify that an operation is still alive.
operation_keep_alive_tx: UnboundedSender<(OperationId, WorkerId)>,
}

impl ApiWorkerSchedulerImpl {
@@ -103,6 +108,20 @@ impl ApiWorkerSchedulerImpl {
timestamp
);
worker.last_update_timestamp = timestamp;
for operation_id in worker.running_action_infos.keys() {
if self
.operation_keep_alive_tx
.send((operation_id.clone(), *worker_id))
.is_err()
{
event!(
Level::ERROR,
?operation_id,
?worker_id,
"OperationKeepAliveTx stream closed"
);
}
}
Ok(())
}

@@ -177,7 +196,7 @@ impl ApiWorkerSchedulerImpl {
&mut self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
let worker = self.workers.get_mut(worker_id).err_tip(|| {
format!("Worker {worker_id} does not exist in SimpleScheduler::update_action")
@@ -193,11 +212,21 @@
.merge(self.immediate_evict_worker(worker_id, err).await);
}

let (is_finished, due_to_backpressure) = match &update {
UpdateOperationType::UpdateWithActionStage(action_stage) => {
(action_stage.is_finished(), false)
}
UpdateOperationType::KeepAlive => (false, false),
UpdateOperationType::UpdateWithError(err) => {
(true, err.code == Code::ResourceExhausted)
}
};

// Update the operation in the worker state manager.
{
let update_operation_res = self
.worker_state_manager
.update_operation(operation_id, worker_id, action_stage.clone())
.update_operation(operation_id, worker_id, update)
.await
.err_tip(|| "in update_operation on SimpleScheduler::update_action");
if let Err(err) = update_operation_res {
@@ -212,10 +241,6 @@
}
}

// We are done if the action is not finished or there was an error.
let is_finished = action_stage
.as_ref()
.map_or_else(|_| true, |action_stage| action_stage.is_finished());
if !is_finished {
return Ok(());
}
@@ -227,9 +252,6 @@
// Note: We need to run this before dealing with backpressure logic.
let complete_action_res = worker.complete_action(operation_id);

let due_to_backpressure = action_stage
.as_ref()
.map_or_else(|e| e.code == Code::ResourceExhausted, |_| false);
// Only pause if there's an action still waiting that will unpause.
if (was_paused || due_to_backpressure) && worker.has_actions() {
worker.is_paused = true;
@@ -296,7 +318,11 @@ impl ApiWorkerSchedulerImpl {
for (operation_id, _) in worker.running_action_infos.drain() {
result = result.merge(
self.worker_state_manager
.update_operation(&operation_id, worker_id, Err(err.clone()))
.update_operation(
&operation_id,
worker_id,
UpdateOperationType::UpdateWithError(err.clone()),
)
.await,
);
}
@@ -319,6 +345,7 @@ pub struct ApiWorkerScheduler {
help = "Timeout of how long to evict workers if no response in this given amount of time in seconds."
)]
worker_timeout_s: u64,
_operation_keep_alive_spawn: JoinHandleDropGuard<()>,
}

impl ApiWorkerScheduler {
@@ -329,15 +356,45 @@ impl ApiWorkerScheduler {
worker_change_notify: Arc<Notify>,
worker_timeout_s: u64,
) -> Arc<Self> {
let (operation_keep_alive_tx, mut operation_keep_alive_rx) = mpsc::unbounded_channel();
Arc::new(Self {
inner: Mutex::new(ApiWorkerSchedulerImpl {
workers: Workers(LruCache::unbounded()),
worker_state_manager,
worker_state_manager: worker_state_manager.clone(),
allocation_strategy,
worker_change_notify,
operation_keep_alive_tx,
}),
platform_property_manager,
worker_timeout_s,
_operation_keep_alive_spawn: spawn!(
"simple_scheduler_operation_keep_alive",
async move {
const RECV_MANY_LIMIT: usize = 256;
let mut messages = Vec::with_capacity(RECV_MANY_LIMIT);
loop {
messages.clear();
operation_keep_alive_rx
.recv_many(&mut messages, RECV_MANY_LIMIT)
.await;
if messages.is_empty() {
return; // Looks like our sender has been dropped.
}
for (operation_id, worker_id) in messages.drain(..) {
let update_operation_res = worker_state_manager
.update_operation(
&operation_id,
&worker_id,
UpdateOperationType::KeepAlive,
)
.await;
if let Err(err) = update_operation_res {
event!(Level::WARN, ?err, "Error while running worker_keep_alive_received, maybe job is done?");
}
}
}
}
),
})
}

@@ -408,12 +465,10 @@ impl WorkerScheduler for ApiWorkerScheduler {
&self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
let mut inner = self.inner.lock().await;
inner
.update_action(worker_id, operation_id, action_stage)
.await
inner.update_action(worker_id, operation_id, update).await
}

async fn worker_keep_alive_received(
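The UpdateOperationType value that replaces the old Result<ActionStage, Error> argument to update_action is defined in nativelink-util's operation_state_manager module and is not shown in this diff. Judging from how it is matched above, its shape is presumably roughly:

```rust
// Presumed shape only; the real definition lives in
// nativelink-util/src/operation_state_manager.rs and may differ in derives and docs.
pub enum UpdateOperationType {
    /// Periodic "still running" signal; no stage change.
    KeepAlive,
    /// The action moved to a new stage (e.g. Executing or Completed).
    UpdateWithActionStage(ActionStage),
    /// The action failed; Code::ResourceExhausted is treated as worker backpressure.
    UpdateWithError(Error),
}
```

Folding keep-alives, stage transitions, and failures into one update path is what lets update_action above derive both is_finished and due_to_backpressure from a single value.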
41 changes: 17 additions & 24 deletions nativelink-scheduler/src/awaited_action_db/awaited_action.rs
@@ -15,7 +15,6 @@
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use bytes::Bytes;
use nativelink_error::{make_input_err, Error, ResultExt};
use nativelink_metric::{
MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
@@ -102,61 +101,55 @@ impl AwaitedAction {
}
}

pub fn version(&self) -> u64 {
pub(crate) fn version(&self) -> u64 {
self.version.0
}

pub fn increment_version(&mut self) {
pub(crate) fn increment_version(&mut self) {
self.version = AwaitedActionVersion(self.version.0 + 1);
}

pub fn action_info(&self) -> &Arc<ActionInfo> {
pub(crate) fn action_info(&self) -> &Arc<ActionInfo> {
&self.action_info
}

pub fn operation_id(&self) -> &OperationId {
pub(crate) fn operation_id(&self) -> &OperationId {
&self.operation_id
}

pub fn sort_key(&self) -> AwaitedActionSortKey {
pub(crate) fn sort_key(&self) -> AwaitedActionSortKey {
self.sort_key
}

pub fn state(&self) -> &Arc<ActionState> {
pub(crate) fn state(&self) -> &Arc<ActionState> {
&self.state
}

pub fn worker_id(&self) -> Option<WorkerId> {
pub(crate) fn worker_id(&self) -> Option<WorkerId> {
self.worker_id
}

pub fn last_worker_updated_timestamp(&self) -> SystemTime {
pub(crate) fn last_worker_updated_timestamp(&self) -> SystemTime {
self.last_worker_updated_timestamp
}

pub(crate) fn keep_alive(&mut self, now: SystemTime) {
self.last_worker_updated_timestamp = now;
}

/// Sets the worker id that is currently processing this action.
pub fn set_worker_id(&mut self, new_maybe_worker_id: Option<WorkerId>) {
pub(crate) fn set_worker_id(&mut self, new_maybe_worker_id: Option<WorkerId>, now: SystemTime) {
if self.worker_id != new_maybe_worker_id {
self.worker_id = new_maybe_worker_id;
self.last_worker_updated_timestamp = SystemTime::now();
self.keep_alive(now);
}
}

/// Sets the current state of the action and notifies subscribers.
/// Returns true if the state was set, false if there are no subscribers.
pub fn set_state(&mut self, mut state: Arc<ActionState>) {
pub(crate) fn set_state(&mut self, mut state: Arc<ActionState>, now: SystemTime) {
std::mem::swap(&mut self.state, &mut state);
self.last_worker_updated_timestamp = SystemTime::now();
}
}

impl TryInto<bytes::Bytes> for AwaitedAction {
type Error = Error;
fn try_into(self) -> Result<Bytes, Self::Error> {
serde_json::to_string(&self)
.map(Bytes::from)
.map_err(|e| make_input_err!("{}", e.to_string()))
.err_tip(|| "In AwaitedAction::TryInto::Bytes")
self.keep_alive(now);
}
}

@@ -216,7 +209,7 @@ impl AwaitedActionSortKey {
Self::new(priority, timestamp)
}

pub fn as_u64(&self) -> u64 {
pub(crate) fn as_u64(&self) -> u64 {
self.0
}
}
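The reason every keep-alive ultimately calls keep_alive(now) and refreshes last_worker_updated_timestamp is so the scheduler can tell live operations apart from ones whose worker has silently died. A hypothetical staleness check built on that timestamp (not part of this commit) could look like:

```rust
use std::time::{Duration, SystemTime};

// Hypothetical helper, not in this commit: has the worker for this operation
// gone quiet for longer than the configured worker_timeout_s?
fn operation_is_stale(
    last_worker_updated_timestamp: SystemTime,
    now: SystemTime,
    worker_timeout_s: u64,
) -> bool {
    now.duration_since(last_worker_updated_timestamp)
        .map(|elapsed| elapsed > Duration::from_secs(worker_timeout_s))
        .unwrap_or(false) // Clock skew: err on the side of "not stale".
}
```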
10 changes: 0 additions & 10 deletions nativelink-scheduler/src/awaited_action_db/mod.rs
@@ -54,16 +54,6 @@ impl TryFrom<ActionStage> for SortedAwaitedActionState {
}
}

impl std::fmt::Display for SortedAwaitedActionState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SortedAwaitedActionState::CacheCheck => write!(f, "CacheCheck"),
SortedAwaitedActionState::Queued => write!(f, "Queued"),
SortedAwaitedActionState::Executing => write!(f, "Executing"),
SortedAwaitedActionState::Completed => write!(f, "Completed"),
}
}
}
/// A struct pointing to an AwaitedAction that can be sorted.
#[derive(Debug, Clone, Serialize, Deserialize, MetricsComponent)]
pub struct SortedAwaitedAction {
10 changes: 4 additions & 6 deletions nativelink-scheduler/src/simple_scheduler.rs
@@ -19,14 +19,12 @@ use async_trait::async_trait;
use futures::Future;
use nativelink_error::{Code, Error, ResultExt};
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
use nativelink_util::action_messages::{
ActionInfo, ActionStage, ActionState, OperationId, WorkerId,
};
use nativelink_util::action_messages::{ActionInfo, ActionState, OperationId, WorkerId};
use nativelink_util::instant_wrapper::InstantWrapper;
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
use nativelink_util::operation_state_manager::{
ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
OperationFilter, OperationStageFlags, OrderDirection,
OperationFilter, OperationStageFlags, OrderDirection, UpdateOperationType,
};
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
@@ -438,10 +436,10 @@ impl WorkerScheduler for SimpleScheduler {
&self,
worker_id: &WorkerId,
operation_id: &OperationId,
action_stage: Result<ActionStage, Error>,
update: UpdateOperationType,
) -> Result<(), Error> {
self.worker_scheduler
.update_action(worker_id, operation_id, action_stage)
.update_action(worker_id, operation_id, update)
.await
}

