Skip to content

Commit

Permalink
Scheduler: Consolidate input_requests and tasks behind one lock
Browse files Browse the repository at this point in the history
Puts the task list and the set of pending input requests into a single
`Inner` struct guarded by one `RwLock`, instead of two separate locks.

With two separate locks there was at least the potential for deadlock.
  • Loading branch information
rdaum committed Jun 24, 2024
1 parent 8b76e02 commit aef4c02
Showing 1 changed file with 73 additions and 67 deletions.
140 changes: 73 additions & 67 deletions crates/kernel/src/tasks/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::HashMap;
use std::fs::File;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

use bincode::{Decode, Encode};
Expand Down Expand Up @@ -73,8 +73,12 @@ pub struct Scheduler {
running: Arc<AtomicBool>,
database: Arc<dyn Database + Send + Sync>,
next_task_id: AtomicUsize,
tasks: Mutex<HashMap<TaskId, TaskControl>>,
input_requests: Mutex<HashMap<Uuid, TaskId>>,
inner: RwLock<Inner>,
}

/// Scheduler state that is kept together behind the single `RwLock` held in
/// `Scheduler::inner`, so both maps are always accessed under the same lock.
struct Inner {
/// Outstanding input requests: maps each input request id to the id of the
/// task recorded as waiting on it. Entries are inserted when a task requests
/// input and removed once that input has been delivered to the task.
input_requests: HashMap<Uuid, TaskId>,
/// Control records for the tasks known to the scheduler, keyed by task id.
tasks: HashMap<TaskId, TaskControl>,
}

#[derive(Clone, Copy, Debug, Eq, PartialEq, Decode, Encode)]
Expand Down Expand Up @@ -166,12 +170,15 @@ impl Scheduler {
pub fn new(database: Arc<dyn Database + Send + Sync>, config: Config) -> Self {
let config = Arc::new(config);
let (control_sender, control_receiver) = crossbeam_channel::unbounded();
let inner = Inner {
input_requests: Default::default(),
tasks: Default::default(),
};
Self {
running: Arc::new(AtomicBool::new(false)),
database,
next_task_id: Default::default(),
tasks: Default::default(),
input_requests: Default::default(),
inner: RwLock::new(inner),
config,
control_sender,
control_receiver,
Expand Down Expand Up @@ -225,14 +232,12 @@ impl Scheduler {
// the given input, clearing the input request out.
trace!(?input_request_id, ?input, "Received input for task");

let mut input_requests = self.input_requests.lock().unwrap();
let Some(task_id) = input_requests.get(&input_request_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task_id) = inner.input_requests.get(&input_request_id) else {
return Err(InputRequestNotFound(input_request_id.as_u128()));
};
let task_id = *task_id;

let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(?task_id, ?input_request_id, "Input received for dead task");
return Err(TaskNotFound(task_id));
};
Expand All @@ -256,7 +261,7 @@ impl Scheduler {
input,
))
.map_err(|_| CouldNotStartTask)?;
input_requests.remove(&input_request_id);
inner.input_requests.remove(&input_request_id);

Ok(())
}
Expand Down Expand Up @@ -436,14 +441,14 @@ impl Scheduler {

pub fn abort_player_tasks(&self, player: Objid) -> Result<(), SchedulerError> {
let mut to_abort = Vec::new();
let mut tasks = self.tasks.lock().unwrap();
for (task_id, task_ref) in tasks.iter() {
let mut inner = self.inner.write().unwrap();
for (task_id, task_ref) in inner.tasks.iter() {
if task_ref.player == player {
to_abort.push(*task_id);
}
}
for task_id in to_abort {
let task = tasks.get_mut(&task_id).expect("Corrupt task list");
let task = inner.tasks.get_mut(&task_id).expect("Corrupt task list");
let tcs = task.task_control_sender.clone();
if let Err(e) = tcs.send(TaskControlMsg::Abort) {
warn!(task_id, error = ?e, "Could not send abort for task. Dead?");
Expand All @@ -456,9 +461,9 @@ impl Scheduler {

/// Request information on all tasks known to the scheduler.
pub fn tasks(&self) -> Result<Vec<TaskDescription>, SchedulerError> {
let inner = self.inner.read().unwrap();
let mut tasks = Vec::new();
let task_lock = self.tasks.lock().unwrap();
for (task_id, task) in task_lock.iter() {
for (task_id, task) in inner.tasks.iter() {
trace!(task_id, "Requesting task description");
let (t_send, t_reply) = oneshot::channel();
let tcs = task.task_control_sender.clone();
Expand All @@ -484,8 +489,8 @@ impl Scheduler {
warn!("Issuing clean shutdown...");
{
// Send shut down to all the tasks.
let tasks = self.tasks.lock().unwrap();
for task in tasks.values() {
let inner = self.inner.read().unwrap();
for task in inner.tasks.values() {
let tcs = task.task_control_sender.clone();
if let Err(e) = tcs.send(TaskControlMsg::Abort) {
warn!(task_id = task.task_id, error = ?e, "Could not send abort for task. Already dead?");
Expand All @@ -497,8 +502,11 @@ impl Scheduler {

// Then spin until they're all done.
loop {
if self.tasks.lock().unwrap().is_empty() {
break;
{
let inner = self.inner.read().unwrap();
if inner.tasks.is_empty() {
break;
}
}
yield_now();
}
Expand All @@ -510,8 +518,8 @@ impl Scheduler {
}

pub fn abort_task(&self, id: TaskId) -> Result<(), SchedulerError> {
let mut tasks = self.tasks.lock().unwrap();
let task = tasks.get_mut(&id).ok_or(TaskNotFound(id))?;
let mut inner = self.inner.write().unwrap();
let task = inner.tasks.get_mut(&id).ok_or(TaskNotFound(id))?;
let tcs = task.task_control_sender.clone();
if let Err(e) = tcs.send(TaskControlMsg::Abort) {
error!(error = ?e, "Could not send abort message to task on its channel. Already dead?");
Expand All @@ -538,8 +546,8 @@ impl Scheduler {
let mut to_wake = Vec::new();
let mut to_prune = Vec::new();
{
let tasks = self.tasks.lock().unwrap();
for (task_id, task) in tasks.iter() {
let inner = self.inner.read().unwrap();
for (task_id, task) in inner.tasks.iter() {
if !task.task_control_sender.is_ready() {
warn!(
task_id,
Expand Down Expand Up @@ -586,8 +594,8 @@ impl Scheduler {
match msg {
SchedulerControlMsg::TaskSuccess(value) => {
// Commit the session.
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for success");
return None;
};
Expand Down Expand Up @@ -634,8 +642,8 @@ impl Scheduler {
warn!(?task_id, "Task cancelled");

// Rollback the session.
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for abort");
return None;
};
Expand Down Expand Up @@ -670,9 +678,9 @@ impl Scheduler {
}
};

// Commit the session.
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
// Commit the session
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for abort");
return None;
};
Expand All @@ -691,8 +699,8 @@ impl Scheduler {
SchedulerControlMsg::TaskException(exception) => {
warn!(?task_id, finally_reason = ?exception, "Task threw exception");

let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for abort");
return None;
};
Expand Down Expand Up @@ -725,8 +733,8 @@ impl Scheduler {
// Task has requested a fork. Dispatch it and reply with the new task id.
// Gotta dump this out til we exit the loop tho, since self.tasks is already
// borrowed here.
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for fork request");
return None;
};
Expand All @@ -741,8 +749,8 @@ impl Scheduler {
// Task is suspended. The resume time (if any) is the system time at which
// the scheduler should try to wake us up.

let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for suspend request");
return None;
};
Expand All @@ -768,8 +776,8 @@ impl Scheduler {

let input_request_id = Uuid::new_v4();
{
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for input request");
return None;
};
Expand All @@ -791,9 +799,8 @@ impl Scheduler {
));
};
task.waiting_input = Some(input_request_id);
inner.input_requests.insert(input_request_id, task_id);
}
let mut input_requests = self.input_requests.lock().unwrap();
input_requests.insert(input_request_id, task_id);
trace!(?task_id, "Task suspended waiting for input");
None
}
Expand Down Expand Up @@ -835,8 +842,8 @@ impl Scheduler {
}
SchedulerControlMsg::Notify { player, event } => {
// Task is asking to notify a player.
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for notify request");
return None;
};
Expand All @@ -855,8 +862,8 @@ impl Scheduler {
Ok(_) => v_string("Scheduler stopping.".to_string()),
Err(e) => v_string(format!("Shutdown failed: {e}")),
};
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task_id, "Task not found for notify request");
return None;
};
Expand Down Expand Up @@ -947,8 +954,8 @@ impl Scheduler {
)?;

let task_id = task_handle.task_id();
let mut tasks = self.tasks.lock().unwrap();
let Some(task_ref) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task_ref) = inner.tasks.get_mut(&task_id) else {
return Err(TaskNotFound(task_id));
};

Expand Down Expand Up @@ -1015,8 +1022,8 @@ impl Scheduler {
}

fn process_notification(&self, task_id: TaskId, result: TaskResult) {
let mut tasks = self.tasks.lock().unwrap();
let Some(task_control) = tasks.remove(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task_control) = inner.tasks.remove(&task_id) else {
// Missing task, must have ended already. This is odd though? So we'll warn.
warn!(task_id, "Task not found for notification, ignoring");
return;
Expand All @@ -1043,9 +1050,9 @@ impl Scheduler {

trace!(?to_wake, "Waking up tasks...");

let mut tasks = self.tasks.lock().unwrap();
let mut inner = self.inner.write().unwrap();
for task_id in to_wake {
let task = tasks.get_mut(task_id).unwrap();
let task = inner.tasks.get_mut(task_id).unwrap();
task.suspended = false;

let world_state_source = self
Expand Down Expand Up @@ -1103,8 +1110,8 @@ impl Scheduler {
task = requesting_task_id,
"Task requesting task descriptions"
);
let tasks_lock = self.tasks.lock().unwrap();
for (task_id, task) in tasks_lock.iter() {
let inner = self.inner.read().unwrap();
for (task_id, task) in inner.tasks.iter() {
// Tasks not in suspended state shouldn't be added.
if !task.suspended {
continue;
Expand Down Expand Up @@ -1165,8 +1172,8 @@ impl Scheduler {
return vec![];
}

let tasks = self.tasks.lock().unwrap();
let victim_task = match tasks.get(&victim_task_id) {
let inner = self.inner.read().unwrap();
let victim_task = match inner.tasks.get(&victim_task_id) {
Some(victim_task) => victim_task,
None => {
result_sender
Expand Down Expand Up @@ -1225,8 +1232,8 @@ impl Scheduler {
}

// Task does not exist.
let mut tasks = self.tasks.lock().unwrap();
let queued_task = match tasks.get_mut(&queued_task_id) {
let mut inner = self.inner.write().unwrap();
let queued_task = match inner.tasks.get_mut(&queued_task_id) {
Some(queued_task) => queued_task,
None => {
result_sender
Expand Down Expand Up @@ -1281,8 +1288,8 @@ impl Scheduler {
}

fn process_retry_request(&self, task_id: TaskId) -> Option<TaskId> {
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&task_id) else {
warn!(task = task_id, "Retrying task not found");
return None;
};
Expand All @@ -1306,8 +1313,8 @@ impl Scheduler {
}

fn process_disconnect(&self, disconnect_task_id: TaskId, player: Objid) {
let mut tasks = self.tasks.lock().unwrap();
let Some(task) = tasks.get_mut(&disconnect_task_id) else {
let mut inner = self.inner.write().unwrap();
let Some(task) = inner.tasks.get_mut(&disconnect_task_id) else {
warn!(task = disconnect_task_id, "Disconnecting task not found");
return;
};
Expand All @@ -1320,8 +1327,7 @@ impl Scheduler {

// Then abort all of their still-living forked tasks (that weren't the disconnect
// task, we need to let that run to completion for sanity's sake.)
let tasks = self.tasks.lock().unwrap();
for (task_id, task) in tasks.iter() {
for (task_id, task) in inner.tasks.iter() {
if *task_id == disconnect_task_id {
continue;
}
Expand All @@ -1342,10 +1348,10 @@ impl Scheduler {
}

fn process_task_removals(&self, to_remove: &[TaskId]) {
let mut tasks = self.tasks.lock().unwrap();
let mut inner = self.inner.write().unwrap();
for task_id in to_remove {
trace!(task = task_id, "Task removed");
tasks.remove(task_id);
inner.tasks.remove(task_id);
}
}

Expand Down Expand Up @@ -1389,8 +1395,8 @@ impl Scheduler {
resume_time: None,
result_sender: Mutex::new(Some(sender)),
};
let mut tasks = self.tasks.lock().unwrap();
tasks.insert(task_id, task_control);
let mut inner = self.inner.write().unwrap();
inner.tasks.insert(task_id, task_control);

// Footgun warning: ALWAYS insert into `inner.tasks` before spawning the task thread!

Expand Down

0 comments on commit aef4c02

Please sign in to comment.