Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-supplied executors #3091

Merged
merged 10 commits into from
Oct 16, 2024
95 changes: 38 additions & 57 deletions any_spawner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
use std::{future::Future, pin::Pin, sync::OnceLock};
use thiserror::Error;

pub(crate) type PinnedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
pub(crate) type PinnedLocalFuture<T> = Pin<Box<dyn Future<Output = T>>>;
/// A future that has been pinned.
pub type PinnedFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// A future that has been pinned.
stefnotch marked this conversation as resolved.
Show resolved Hide resolved
pub type PinnedLocalFuture<T> = Pin<Box<dyn Future<Output = T>>>;

static SPAWN: OnceLock<fn(PinnedFuture<()>)> = OnceLock::new();
static SPAWN_LOCAL: OnceLock<fn(PinnedLocalFuture<()>)> = OnceLock::new();
Expand Down Expand Up @@ -284,63 +286,42 @@ impl Executor {
.map_err(|_| ExecutorError::AlreadySet)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
#[cfg(feature = "futures-executor")]
#[test]
fn can_spawn_local_future() {
use crate::Executor;
use std::rc::Rc;
_ = Executor::init_futures_executor();
let rc = Rc::new(());
Executor::spawn_local(async {
_ = rc;
});
Executor::spawn(async {});
}
/// Globally sets a custom executor as the executor used to spawn tasks.
///
/// Returns `Err(_)` if an executor has already been set.
pub fn init_custom_executor(
custom_executor: impl CustomExecutor + 'static,
) -> Result<(), ExecutorError> {
static EXECUTOR: OnceLock<Box<dyn CustomExecutor>> = OnceLock::new();
EXECUTOR
.set(Box::new(custom_executor))
.map_err(|_| ExecutorError::AlreadySet)?;

#[cfg(feature = "futures-executor")]
#[test]
fn can_make_threaded_progress() {
use crate::Executor;
use std::sync::{atomic::AtomicUsize, Arc};
_ = Executor::init_futures_executor();
let counter = Arc::new(AtomicUsize::new(0));
Executor::spawn({
let counter = Arc::clone(&counter);
async move {
assert_eq!(
counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel),
0
);
}
});
futures::executor::block_on(Executor::tick());
assert_eq!(counter.load(std::sync::atomic::Ordering::Acquire), 1);
SPAWN
.set(|fut| {
EXECUTOR.get().unwrap().spawn(fut);
})
.map_err(|_| ExecutorError::AlreadySet)?;
SPAWN_LOCAL
.set(|fut| EXECUTOR.get().unwrap().spawn_local(fut))
.map_err(|_| ExecutorError::AlreadySet)?;
POLL_LOCAL
.set(|| EXECUTOR.get().unwrap().poll_local())
.map_err(|_| ExecutorError::AlreadySet)?;
Ok(())
}
}

#[cfg(feature = "futures-executor")]
#[test]
fn can_make_local_progress() {
use crate::Executor;
use std::sync::{atomic::AtomicUsize, Arc};
_ = Executor::init_futures_executor();
let counter = Arc::new(AtomicUsize::new(0));
Executor::spawn_local({
let counter = Arc::clone(&counter);
async move {
assert_eq!(
counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel),
0
);
Executor::spawn_local(async {
// Should not crash
});
}
});
Executor::poll_local();
assert_eq!(counter.load(std::sync::atomic::Ordering::Acquire), 1);
}
/// A trait for custom executors.
/// Custom executors can be used to integrate with any executor that supports spawning futures.
///
/// All methods can be called recursively.
pub trait CustomExecutor: Send + Sync {
/// Spawns a future, usually on a thread pool.
fn spawn(&self, fut: PinnedFuture<()>);
/// Spawns a local future. May require calling `poll_local` to make progress.
fn spawn_local(&self, fut: PinnedLocalFuture<()>);
/// Polls the executor, if it supports polling.
fn poll_local(&self);
}
55 changes: 55 additions & 0 deletions any_spawner/tests/custom_runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#[cfg(feature = "futures-executor")]
use any_spawner::{CustomExecutor, Executor, PinnedFuture, PinnedLocalFuture};
#[cfg(feature = "futures-executor")]
#[test]
fn can_create_custom_executor() {
use futures::{
executor::{LocalPool, LocalSpawner},
task::LocalSpawnExt,
};
use std::{
cell::RefCell,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};

thread_local! {
static LOCAL_POOL: RefCell<LocalPool> = RefCell::new(LocalPool::new());
static SPAWNER: LocalSpawner = LOCAL_POOL.with(|pool| pool.borrow().spawner());
}

struct CustomFutureExecutor;
impl CustomExecutor for CustomFutureExecutor {
fn spawn(&self, _fut: PinnedFuture<()>) {
panic!("not supported in this test");
}

fn spawn_local(&self, fut: PinnedLocalFuture<()>) {
SPAWNER.with(|spawner| {
spawner.spawn_local(fut).expect("failed to spawn future");
});
}

fn poll_local(&self) {
LOCAL_POOL.with(|pool| {
if let Ok(mut pool) = pool.try_borrow_mut() {
pool.run_until_stalled();
}
// If we couldn't borrow_mut, we're in a nested call to poll, so we don't need to do anything.
});
}
}

Executor::init_custom_executor(CustomFutureExecutor)
.expect("couldn't set executor");

let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
Executor::spawn_local(async move {
counter_clone.store(1, Ordering::Release);
});
Executor::poll_local();
assert_eq!(counter.load(Ordering::Acquire), 1);
}
38 changes: 38 additions & 0 deletions any_spawner/tests/futures_runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#[cfg(feature = "futures-executor")]
use any_spawner::Executor;
// All tests in this file use the same executor.

#[cfg(feature = "futures-executor")]
#[test]
fn can_spawn_local_future() {
use std::rc::Rc;

let _ = Executor::init_futures_executor();
let rc = Rc::new(());
Executor::spawn_local(async {
_ = rc;
});
Executor::spawn(async {});
}

#[cfg(feature = "futures-executor")]
#[test]
fn can_make_local_progress() {
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
let _ = Executor::init_futures_executor();
let counter = Arc::new(AtomicUsize::new(0));
Executor::spawn_local({
let counter = Arc::clone(&counter);
async move {
assert_eq!(counter.fetch_add(1, Ordering::AcqRel), 0);
Executor::spawn_local(async {
// Should not crash
});
}
});
Executor::poll_local();
assert_eq!(counter.load(Ordering::Acquire), 1);
}