diff --git a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs index 16ab5f789675d..0314337b8a066 100644 --- a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs +++ b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor}; -use bevy_utils::default; use bevy_utils::syncunsafecell::SyncUnsafeCell; #[cfg(feature = "trace")] use bevy_utils::tracing::{info_span, Instrument}; +use bevy_utils::{default, Duration, Instant}; use std::panic::AssertUnwindSafe; use async_channel::{Receiver, Sender}; @@ -23,6 +23,14 @@ use crate::{ use crate as bevy_ecs; +const EXECUTION_TIME_DECAY_RATE: f32 = 0.25; +const INLINE_THRESHOLD: Duration = Duration::from_micros(2); + +struct SystemCompletion { + system_index: usize, + duration: Duration, +} + /// A funky borrow split of [`SystemSchedule`] required by the [`MultiThreadedExecutor`]. struct SyncUnsafeSchedule<'a> { systems: &'a [SyncUnsafeCell], @@ -61,14 +69,26 @@ struct SystemTaskMetadata { is_send: bool, /// Is `true` if the system is exclusive. is_exclusive: bool, + /// Weighted average of the execution time of the system. + avg_execution_time: Option, +} + +impl SystemTaskMetadata { + fn update_exeuction_time(&mut self, mut execution_time: Duration) { + if let Some(avg_execution_time) = self.avg_execution_time { + execution_time = execution_time.mul_f32(EXECUTION_TIME_DECAY_RATE) + + avg_execution_time.mul_f32(1.0 - EXECUTION_TIME_DECAY_RATE); + } + self.avg_execution_time = Some(execution_time); + } } /// Runs the schedule using a thread pool. Non-conflicting systems can run in parallel. pub struct MultiThreadedExecutor { /// Sends system completion events. - sender: Sender, + sender: Sender, /// Receives system completion events. - receiver: Receiver, + receiver: Receiver, /// Metadata for scheduling and running system tasks. system_task_metadata: Vec, /// Union of the accesses of all currently running systems. @@ -99,6 +119,8 @@ pub struct MultiThreadedExecutor { unapplied_systems: FixedBitSet, /// Setting when true applies system buffers after all systems have run apply_final_buffers: bool, + // A queue of inlined systems that have yet to run. + // inlined_system_queue: Vec, } impl Default for MultiThreadedExecutor { @@ -136,6 +158,7 @@ impl SystemExecutor for MultiThreadedExecutor { dependents: schedule.system_dependents[index].clone(), is_send: schedule.systems[index].is_send(), is_exclusive: schedule.systems[index].is_exclusive(), + avg_execution_time: None, }); } @@ -186,15 +209,15 @@ impl SystemExecutor for MultiThreadedExecutor { if self.num_running_systems > 0 { // wait for systems to complete - let index = + let completion = self.receiver.recv().await.expect( "A system has panicked so the executor cannot continue.", ); - self.finish_system_and_signal_dependents(index); + self.finish_system_and_signal_dependents(completion); - while let Ok(index) = self.receiver.try_recv() { - self.finish_system_and_signal_dependents(index); + while let Ok(completion) = self.receiver.try_recv() { + self.finish_system_and_signal_dependents(completion); } self.rebuild_active_access(); @@ -250,6 +273,7 @@ impl MultiThreadedExecutor { completed_systems: FixedBitSet::new(), unapplied_systems: FixedBitSet::new(), apply_final_buffers: true, + // inlined_system_queue: Vec::new(), } } @@ -430,6 +454,35 @@ impl MultiThreadedExecutor { // SAFETY: this system is not running, no other reference exists let system = unsafe { &mut *systems[system_index].get() }; + let &SystemTaskMetadata { + is_send, + avg_execution_time, + .. + } = &self.system_task_metadata[system_index]; + if let Some(avg_execution_time) = avg_execution_time { + if is_send && avg_execution_time <= INLINE_THRESHOLD { + // #[cfg(feature = "trace")] + // let _system_span = info_span!("system", name = &*system.name()).entered(); + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + let start = Instant::now(); + // SAFETY: access is compatible + unsafe { system.run_unsafe((), world) }; + Instant::now() - start + })); + if let Ok(duration) = res { + self.finish_system_and_signal_dependents(SystemCompletion { + system_index, + duration, + }); + } else { + // close the channel to propagate the error to the + // multithreaded executor + self.sender.close(); + } + return; + } + } + #[cfg(feature = "trace")] let task_span = info_span!("system_task", name = &*system.name()); #[cfg(feature = "trace")] @@ -440,20 +493,25 @@ impl MultiThreadedExecutor { #[cfg(feature = "trace")] let system_guard = system_span.enter(); let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + let start = Instant::now(); // SAFETY: access is compatible unsafe { system.run_unsafe((), world) }; + Instant::now() - start })); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { + if let Ok(duration) = res { sender - .send(system_index) + .send(SystemCompletion { + system_index, + duration, + }) .await .unwrap_or_else(|error| unreachable!("{}", error)); + } else { + // close the channel to propagate the error to the + // multithreaded executor + sender.close(); } }; @@ -498,19 +556,24 @@ impl MultiThreadedExecutor { #[cfg(feature = "trace")] let system_guard = system_span.enter(); let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + let start = Instant::now(); apply_system_buffers(&unapplied_systems, systems, world); + Instant::now() - start })); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { + if let Ok(duration) = res { sender - .send(system_index) + .send(SystemCompletion { + system_index, + duration, + }) .await .unwrap_or_else(|error| unreachable!("{}", error)); + } else { + // close the channel to propagate the error to the + // multithreaded executor + sender.close(); } }; @@ -522,19 +585,24 @@ impl MultiThreadedExecutor { #[cfg(feature = "trace")] let system_guard = system_span.enter(); let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + let start = Instant::now(); system.run((), world); + Instant::now() - start })); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { + if let Ok(duration) = res { sender - .send(system_index) + .send(SystemCompletion { + system_index, + duration, + }) .await .unwrap_or_else(|error| unreachable!("{}", error)); + } else { + // close the channel to propagate the error to the + // multithreaded executor + sender.close(); } }; @@ -547,12 +615,18 @@ impl MultiThreadedExecutor { self.local_thread_running = true; } - fn finish_system_and_signal_dependents(&mut self, system_index: usize) { - if self.system_task_metadata[system_index].is_exclusive { + fn finish_system_and_signal_dependents(&mut self, completion: SystemCompletion) { + let SystemCompletion { + system_index, + duration, + } = completion; + let metadata = &mut self.system_task_metadata[system_index]; + metadata.update_exeuction_time(duration); + if metadata.is_exclusive { self.exclusive_running = false; } - if !self.system_task_metadata[system_index].is_send { + if !metadata.is_send { self.local_thread_running = false; }