Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline "small" systems into the multithreaded executor #7693

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 102 additions & 28 deletions crates/bevy_ecs/src/schedule/executor/multi_threaded.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::sync::Arc;

use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor};
use bevy_utils::default;
use bevy_utils::syncunsafecell::SyncUnsafeCell;
#[cfg(feature = "trace")]
use bevy_utils::tracing::{info_span, Instrument};
use bevy_utils::{default, Duration, Instant};
use std::panic::AssertUnwindSafe;

use async_channel::{Receiver, Sender};
Expand All @@ -23,6 +23,14 @@ use crate::{

use crate as bevy_ecs;

const EXECUTION_TIME_DECAY_RATE: f32 = 0.25;
const INLINE_THRESHOLD: Duration = Duration::from_micros(2);

struct SystemCompletion {
system_index: usize,
duration: Duration,
}

/// A funky borrow split of [`SystemSchedule`] required by the [`MultiThreadedExecutor`].
struct SyncUnsafeSchedule<'a> {
systems: &'a [SyncUnsafeCell<BoxedSystem>],
Expand Down Expand Up @@ -61,14 +69,26 @@ struct SystemTaskMetadata {
is_send: bool,
/// Is `true` if the system is exclusive.
is_exclusive: bool,
/// Weighted average of the execution time of the system.
avg_execution_time: Option<Duration>,
}

impl SystemTaskMetadata {
fn update_exeuction_time(&mut self, mut execution_time: Duration) {
if let Some(avg_execution_time) = self.avg_execution_time {
execution_time = execution_time.mul_f32(EXECUTION_TIME_DECAY_RATE)
+ avg_execution_time.mul_f32(1.0 - EXECUTION_TIME_DECAY_RATE);
}
self.avg_execution_time = Some(execution_time);
}
}

/// Runs the schedule using a thread pool. Non-conflicting systems can run in parallel.
pub struct MultiThreadedExecutor {
/// Sends system completion events.
sender: Sender<usize>,
sender: Sender<SystemCompletion>,
/// Receives system completion events.
receiver: Receiver<usize>,
receiver: Receiver<SystemCompletion>,
/// Metadata for scheduling and running system tasks.
system_task_metadata: Vec<SystemTaskMetadata>,
/// Union of the accesses of all currently running systems.
Expand Down Expand Up @@ -99,6 +119,8 @@ pub struct MultiThreadedExecutor {
unapplied_systems: FixedBitSet,
/// Setting when true applies system buffers after all systems have run
apply_final_buffers: bool,
// A queue of inlined systems that have yet to run.
// inlined_system_queue: Vec<usize>,
}

impl Default for MultiThreadedExecutor {
Expand Down Expand Up @@ -136,6 +158,7 @@ impl SystemExecutor for MultiThreadedExecutor {
dependents: schedule.system_dependents[index].clone(),
is_send: schedule.systems[index].is_send(),
is_exclusive: schedule.systems[index].is_exclusive(),
avg_execution_time: None,
});
}

Expand Down Expand Up @@ -186,15 +209,15 @@ impl SystemExecutor for MultiThreadedExecutor {

if self.num_running_systems > 0 {
// wait for systems to complete
let index =
let completion =
self.receiver.recv().await.expect(
"A system has panicked so the executor cannot continue.",
);

self.finish_system_and_signal_dependents(index);
self.finish_system_and_signal_dependents(completion);

while let Ok(index) = self.receiver.try_recv() {
self.finish_system_and_signal_dependents(index);
while let Ok(completion) = self.receiver.try_recv() {
self.finish_system_and_signal_dependents(completion);
}

self.rebuild_active_access();
Expand Down Expand Up @@ -250,6 +273,7 @@ impl MultiThreadedExecutor {
completed_systems: FixedBitSet::new(),
unapplied_systems: FixedBitSet::new(),
apply_final_buffers: true,
// inlined_system_queue: Vec::new(),
}
}

Expand Down Expand Up @@ -430,6 +454,35 @@ impl MultiThreadedExecutor {
// SAFETY: this system is not running, no other reference exists
let system = unsafe { &mut *systems[system_index].get() };

let &SystemTaskMetadata {
is_send,
avg_execution_time,
..
} = &self.system_task_metadata[system_index];
if let Some(avg_execution_time) = avg_execution_time {
if is_send && avg_execution_time <= INLINE_THRESHOLD {
// #[cfg(feature = "trace")]
// let _system_span = info_span!("system", name = &*system.name()).entered();
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
let start = Instant::now();
// SAFETY: access is compatible
unsafe { system.run_unsafe((), world) };
Instant::now() - start
}));
if let Ok(duration) = res {
self.finish_system_and_signal_dependents(SystemCompletion {
system_index,
duration,
});
} else {
// close the channel to propagate the error to the
// multithreaded executor
self.sender.close();
}
return;
}
}

#[cfg(feature = "trace")]
let task_span = info_span!("system_task", name = &*system.name());
#[cfg(feature = "trace")]
Expand All @@ -440,20 +493,25 @@ impl MultiThreadedExecutor {
#[cfg(feature = "trace")]
let system_guard = system_span.enter();
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
let start = Instant::now();
// SAFETY: access is compatible
unsafe { system.run_unsafe((), world) };
Instant::now() - start
}));
#[cfg(feature = "trace")]
drop(system_guard);
if res.is_err() {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
} else {
if let Ok(duration) = res {
sender
.send(system_index)
.send(SystemCompletion {
system_index,
duration,
})
.await
.unwrap_or_else(|error| unreachable!("{}", error));
} else {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
}
};

Expand Down Expand Up @@ -498,19 +556,24 @@ impl MultiThreadedExecutor {
#[cfg(feature = "trace")]
let system_guard = system_span.enter();
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
let start = Instant::now();
apply_system_buffers(&unapplied_systems, systems, world);
Instant::now() - start
}));
#[cfg(feature = "trace")]
drop(system_guard);
if res.is_err() {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
} else {
if let Ok(duration) = res {
sender
.send(system_index)
.send(SystemCompletion {
system_index,
duration,
})
.await
.unwrap_or_else(|error| unreachable!("{}", error));
} else {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
}
};

Expand All @@ -522,19 +585,24 @@ impl MultiThreadedExecutor {
#[cfg(feature = "trace")]
let system_guard = system_span.enter();
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
let start = Instant::now();
system.run((), world);
Instant::now() - start
}));
#[cfg(feature = "trace")]
drop(system_guard);
if res.is_err() {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
} else {
if let Ok(duration) = res {
sender
.send(system_index)
.send(SystemCompletion {
system_index,
duration,
})
.await
.unwrap_or_else(|error| unreachable!("{}", error));
} else {
// close the channel to propagate the error to the
// multithreaded executor
sender.close();
}
};

Expand All @@ -547,12 +615,18 @@ impl MultiThreadedExecutor {
self.local_thread_running = true;
}

fn finish_system_and_signal_dependents(&mut self, system_index: usize) {
if self.system_task_metadata[system_index].is_exclusive {
fn finish_system_and_signal_dependents(&mut self, completion: SystemCompletion) {
let SystemCompletion {
system_index,
duration,
} = completion;
let metadata = &mut self.system_task_metadata[system_index];
metadata.update_exeuction_time(duration);
if metadata.is_exclusive {
self.exclusive_running = false;
}

if !self.system_task_metadata[system_index].is_send {
if !metadata.is_send {
self.local_thread_running = false;
}

Expand Down