diff --git a/crates/kernel/src/tasks/scheduler.rs b/crates/kernel/src/tasks/scheduler.rs index 5bb4cf29..bba18e04 100644 --- a/crates/kernel/src/tasks/scheduler.rs +++ b/crates/kernel/src/tasks/scheduler.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use std::fs::File; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use bincode::{Decode, Encode}; @@ -73,8 +73,12 @@ pub struct Scheduler { running: Arc, database: Arc, next_task_id: AtomicUsize, - tasks: Mutex>, - input_requests: Mutex>, + inner: RwLock, +} + +struct Inner { + input_requests: HashMap, + tasks: HashMap, } #[derive(Clone, Copy, Debug, Eq, PartialEq, Decode, Encode)] @@ -166,12 +170,15 @@ impl Scheduler { pub fn new(database: Arc, config: Config) -> Self { let config = Arc::new(config); let (control_sender, control_receiver) = crossbeam_channel::unbounded(); + let inner = Inner { + input_requests: Default::default(), + tasks: Default::default(), + }; Self { running: Arc::new(AtomicBool::new(false)), database, next_task_id: Default::default(), - tasks: Default::default(), - input_requests: Default::default(), + inner: RwLock::new(inner), config, control_sender, control_receiver, @@ -225,14 +232,12 @@ impl Scheduler { // the given input, clearing the input request out. trace!(?input_request_id, ?input, "Received input for task"); - let mut input_requests = self.input_requests.lock().unwrap(); - let Some(task_id) = input_requests.get(&input_request_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task_id) = inner.input_requests.get(&input_request_id) else { return Err(InputRequestNotFound(input_request_id.as_u128())); }; let task_id = *task_id; - - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(?task_id, ?input_request_id, "Input received for dead task"); return Err(TaskNotFound(task_id)); }; @@ -256,7 +261,7 @@ impl Scheduler { input, )) .map_err(|_| CouldNotStartTask)?; - input_requests.remove(&input_request_id); + inner.input_requests.remove(&input_request_id); Ok(()) } @@ -436,14 +441,14 @@ impl Scheduler { pub fn abort_player_tasks(&self, player: Objid) -> Result<(), SchedulerError> { let mut to_abort = Vec::new(); - let mut tasks = self.tasks.lock().unwrap(); - for (task_id, task_ref) in tasks.iter() { + let mut inner = self.inner.write().unwrap(); + for (task_id, task_ref) in inner.tasks.iter() { if task_ref.player == player { to_abort.push(*task_id); } } for task_id in to_abort { - let task = tasks.get_mut(&task_id).expect("Corrupt task list"); + let task = inner.tasks.get_mut(&task_id).expect("Corrupt task list"); let tcs = task.task_control_sender.clone(); if let Err(e) = tcs.send(TaskControlMsg::Abort) { warn!(task_id, error = ?e, "Could not send abort for task. Dead?"); @@ -456,9 +461,9 @@ impl Scheduler { /// Request information on all tasks known to the scheduler. pub fn tasks(&self) -> Result, SchedulerError> { + let inner = self.inner.read().unwrap(); let mut tasks = Vec::new(); - let task_lock = self.tasks.lock().unwrap(); - for (task_id, task) in task_lock.iter() { + for (task_id, task) in inner.tasks.iter() { trace!(task_id, "Requesting task description"); let (t_send, t_reply) = oneshot::channel(); let tcs = task.task_control_sender.clone(); @@ -484,8 +489,8 @@ impl Scheduler { warn!("Issuing clean shutdown..."); { // Send shut down to all the tasks. - let tasks = self.tasks.lock().unwrap(); - for task in tasks.values() { + let inner = self.inner.read().unwrap(); + for task in inner.tasks.values() { let tcs = task.task_control_sender.clone(); if let Err(e) = tcs.send(TaskControlMsg::Abort) { warn!(task_id = task.task_id, error = ?e, "Could not send abort for task. Already dead?"); @@ -497,8 +502,11 @@ impl Scheduler { // Then spin until they're all done. loop { - if self.tasks.lock().unwrap().is_empty() { - break; + { + let inner = self.inner.read().unwrap(); + if inner.tasks.is_empty() { + break; + } } yield_now(); } @@ -510,8 +518,8 @@ impl Scheduler { } pub fn abort_task(&self, id: TaskId) -> Result<(), SchedulerError> { - let mut tasks = self.tasks.lock().unwrap(); - let task = tasks.get_mut(&id).ok_or(TaskNotFound(id))?; + let mut inner = self.inner.write().unwrap(); + let task = inner.tasks.get_mut(&id).ok_or(TaskNotFound(id))?; let tcs = task.task_control_sender.clone(); if let Err(e) = tcs.send(TaskControlMsg::Abort) { error!(error = ?e, "Could not send abort message to task on its channel. Already dead?"); @@ -538,8 +546,8 @@ impl Scheduler { let mut to_wake = Vec::new(); let mut to_prune = Vec::new(); { - let tasks = self.tasks.lock().unwrap(); - for (task_id, task) in tasks.iter() { + let inner = self.inner.read().unwrap(); + for (task_id, task) in inner.tasks.iter() { if !task.task_control_sender.is_ready() { warn!( task_id, @@ -586,8 +594,8 @@ impl Scheduler { match msg { SchedulerControlMsg::TaskSuccess(value) => { // Commit the session. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for success"); return None; }; @@ -634,8 +642,8 @@ impl Scheduler { warn!(?task_id, "Task cancelled"); // Rollback the session. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for abort"); return None; }; @@ -670,9 +678,9 @@ impl Scheduler { } }; - // Commit the session. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + // Commit the session + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for abort"); return None; }; @@ -691,8 +699,8 @@ impl Scheduler { SchedulerControlMsg::TaskException(exception) => { warn!(?task_id, finally_reason = ?exception, "Task threw exception"); - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for abort"); return None; }; @@ -725,8 +733,8 @@ impl Scheduler { // Task has requested a fork. Dispatch it and reply with the new task id. // Gotta dump this out til we exit the loop tho, since self.tasks is already // borrowed here. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for fork request"); return None; }; @@ -741,8 +749,8 @@ impl Scheduler { // Task is suspended. The resume time (if any) is the system time at which // the scheduler should try to wake us up. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for suspend request"); return None; }; @@ -768,8 +776,8 @@ impl Scheduler { let input_request_id = Uuid::new_v4(); { - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for input request"); return None; }; @@ -791,9 +799,8 @@ impl Scheduler { )); }; task.waiting_input = Some(input_request_id); + inner.input_requests.insert(input_request_id, task_id); } - let mut input_requests = self.input_requests.lock().unwrap(); - input_requests.insert(input_request_id, task_id); trace!(?task_id, "Task suspended waiting for input"); None } @@ -835,8 +842,8 @@ impl Scheduler { } SchedulerControlMsg::Notify { player, event } => { // Task is asking to notify a player. - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for notify request"); return None; }; @@ -855,8 +862,8 @@ impl Scheduler { Ok(_) => v_string("Scheduler stopping.".to_string()), Err(e) => v_string(format!("Shutdown failed: {e}")), }; - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task_id, "Task not found for notify request"); return None; }; @@ -947,8 +954,8 @@ impl Scheduler { )?; let task_id = task_handle.task_id(); - let mut tasks = self.tasks.lock().unwrap(); - let Some(task_ref) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task_ref) = inner.tasks.get_mut(&task_id) else { return Err(TaskNotFound(task_id)); }; @@ -1015,8 +1022,8 @@ impl Scheduler { } fn process_notification(&self, task_id: TaskId, result: TaskResult) { - let mut tasks = self.tasks.lock().unwrap(); - let Some(task_control) = tasks.remove(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task_control) = inner.tasks.remove(&task_id) else { // Missing task, must have ended already. This is odd though? So we'll warn. warn!(task_id, "Task not found for notification, ignoring"); return; @@ -1043,9 +1050,9 @@ impl Scheduler { trace!(?to_wake, "Waking up tasks..."); - let mut tasks = self.tasks.lock().unwrap(); + let mut inner = self.inner.write().unwrap(); for task_id in to_wake { - let task = tasks.get_mut(task_id).unwrap(); + let task = inner.tasks.get_mut(task_id).unwrap(); task.suspended = false; let world_state_source = self @@ -1103,8 +1110,8 @@ impl Scheduler { task = requesting_task_id, "Task requesting task descriptions" ); - let tasks_lock = self.tasks.lock().unwrap(); - for (task_id, task) in tasks_lock.iter() { + let inner = self.inner.read().unwrap(); + for (task_id, task) in inner.tasks.iter() { // Tasks not in suspended state shouldn't be added. if !task.suspended { continue; @@ -1165,8 +1172,8 @@ impl Scheduler { return vec![]; } - let tasks = self.tasks.lock().unwrap(); - let victim_task = match tasks.get(&victim_task_id) { + let inner = self.inner.read().unwrap(); + let victim_task = match inner.tasks.get(&victim_task_id) { Some(victim_task) => victim_task, None => { result_sender @@ -1225,8 +1232,8 @@ impl Scheduler { } // Task does not exist. - let mut tasks = self.tasks.lock().unwrap(); - let queued_task = match tasks.get_mut(&queued_task_id) { + let mut inner = self.inner.write().unwrap(); + let queued_task = match inner.tasks.get_mut(&queued_task_id) { Some(queued_task) => queued_task, None => { result_sender @@ -1281,8 +1288,8 @@ impl Scheduler { } fn process_retry_request(&self, task_id: TaskId) -> Option { - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&task_id) else { warn!(task = task_id, "Retrying task not found"); return None; }; @@ -1306,8 +1313,8 @@ impl Scheduler { } fn process_disconnect(&self, disconnect_task_id: TaskId, player: Objid) { - let mut tasks = self.tasks.lock().unwrap(); - let Some(task) = tasks.get_mut(&disconnect_task_id) else { + let mut inner = self.inner.write().unwrap(); + let Some(task) = inner.tasks.get_mut(&disconnect_task_id) else { warn!(task = disconnect_task_id, "Disconnecting task not found"); return; }; @@ -1320,8 +1327,7 @@ impl Scheduler { // Then abort all of their still-living forked tasks (that weren't the disconnect // task, we need to let that run to completion for sanity's sake.) - let tasks = self.tasks.lock().unwrap(); - for (task_id, task) in tasks.iter() { + for (task_id, task) in inner.tasks.iter() { if *task_id == disconnect_task_id { continue; } @@ -1342,10 +1348,10 @@ impl Scheduler { } fn process_task_removals(&self, to_remove: &[TaskId]) { - let mut tasks = self.tasks.lock().unwrap(); + let mut inner = self.inner.write().unwrap(); for task_id in to_remove { trace!(task = task_id, "Task removed"); - tasks.remove(task_id); + inner.tasks.remove(task_id); } } @@ -1389,8 +1395,8 @@ impl Scheduler { resume_time: None, result_sender: Mutex::new(Some(sender)), }; - let mut tasks = self.tasks.lock().unwrap(); - tasks.insert(task_id, task_control); + let mut inner = self.inner.write().unwrap(); + inner.tasks.insert(task_id, task_control); // Footgun warning: ALWAYS `self.tasks.insert` before spawning the task thread!