Skip to content

Commit

Permalink
Rework tuple::join impl
Browse files Browse the repository at this point in the history
  • Loading branch information
matheus-consoli committed Nov 15, 2022
1 parent be1d966 commit 32a5233
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 79 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ harness = false
[dependencies]
bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] }
futures-core = "0.3"
paste = "1.0.9"
pin-project = "1.0.8"

[dev-dependencies]
Expand Down
182 changes: 104 additions & 78 deletions src/future/join/tuple.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,101 @@
use super::Join as JoinTrait;
use crate::utils::PollState;
use crate::utils::{self, PollState, PollStates};

use core::fmt::{self, Debug};
use core::future::{Future, IntoFuture};
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::{pin_project, pinned_drop};
use pin_project::pin_project;

use paste::paste;
macro_rules! maybe_poll {
($idx:tt, $len:ident, $this:ident, $fut:ident, $cx:ident) => {
if $this.states[$idx].is_pending() {
if let Poll::Ready(out) = $this.$fut.poll($cx) {
$this.outputs.$idx = MaybeUninit::new(out);
$this.states[$idx] = PollState::Done;
*$this.len -= 1;
}
}
};
}

macro_rules! poll_all_pending {
(@inner 0, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(0, $len, $this, $fut, $cx);
poll_all_pending!(@inner 1, $len, $this, $cx, ($($rest,)*));
};
(@inner 1, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(1, $len, $this, $fut, $cx);
poll_all_pending!(@inner 2, $len, $this, $cx, ($($rest,)*));
};
(@inner 2, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(2, $len, $this, $fut, $cx);
poll_all_pending!(@inner 3, $len, $this, $cx, ($($rest,)*));
};
(@inner 3, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(3, $len, $this, $fut, $cx);
poll_all_pending!(@inner 4, $len, $this, $cx, ($($rest,)*));
};
(@inner 4, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(4, $len, $this, $fut, $cx);
poll_all_pending!(@inner 5, $len, $this, $cx, ($($rest,)*));
};
(@inner 5, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(5, $len, $this, $fut, $cx);
poll_all_pending!(@inner 6, $len, $this, $cx, ($($rest,)*));
};
(@inner 6, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(6, $len, $this, $fut, $cx);
poll_all_pending!(@inner 7, $len, $this, $cx, ($($rest,)*));
};
(@inner 7, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(7, $len, $this, $fut, $cx);
poll_all_pending!(@inner 8, $len, $this, $cx, ($($rest,)*));
};
(@inner 8, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(8, $len, $this, $fut, $cx);
poll_all_pending!(@inner 9, $len, $this, $cx, ($($rest,)*));
};
(@inner 9, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(9, $len, $this, $fut, $cx);
poll_all_pending!(@inner 10, $len, $this, $cx, ($($rest,)*));
};
(@inner 10, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(10, $len, $this, $fut, $cx);
poll_all_pending!(@inner 11, $len, $this, $cx, ($($rest,)*));
};
(@inner 11, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(11, $len, $this, $fut, $cx);
poll_all_pending!(@inner 12, $len, $this, $cx, ($($rest,)*));
};
(@inner 12, $len:ident, $this:ident, $cx:ident, ($fut:ident, $($rest:ident,)*)) => {
maybe_poll!(12, $len, $this, $fut, $cx);
};
(@inner $ignore:literal, $len:ident, $this:ident, $cx:ident, ()) => { };
($len:ident, $this:ident, $cx:ident, $($F:ident,)*) => {
poll_all_pending!(@inner 0, $len, $this, $cx, ($($F,)*));
};
}

macro_rules! impl_join_tuple {
($StructName:ident $($F:ident)*) => {
paste!{
/// Waits for two similarly-typed futures to complete.
///
/// This `struct` is created by the [`join`] method on the [`Join`] trait. See
/// its documentation for more.
///
/// [`join`]: crate::future::Join::join
/// [`Join`]: crate::future::Join
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
pub struct $StructName<$($F: Future),*> {
done: bool,
$(
#[pin] $F: $F,
[<$F _out>]: MaybeUninit<$F::Output>,
[<$F _state>]: PollState,
)*
}
/// Waits for two similarly-typed futures to complete.
///
/// This `struct` is created by the [`join`] method on the [`Join`] trait. See
/// its documentation for more.
///
/// [`join`]: crate::future::Join::join
/// [`Join`]: crate::future::Join
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[allow(non_snake_case)]
pub struct $StructName<$($F: Future),*> {
len: u32,
$(#[pin] $F: $F,)*
outputs: ($(MaybeUninit<$F::Output>,)*),
states: PollStates,
}

impl<$($F),*> Debug for $StructName<$($F),*>
Expand All @@ -40,45 +104,29 @@ macro_rules! impl_join_tuple {
$F::Output: Debug,
)* {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
paste!{
f.debug_tuple("Join")
$(.field(&self.$F).field(&self.[<$F _state >]))*
.finish()
}
f.debug_tuple("Join")
.field(&($(&self.$F,)*))
.field(&self.states)
.finish()
}
}

#[allow(unused_mut)]
#[allow(unused_parens)]
#[allow(unused_variables)]
impl<$($F: Future),*> Future for $StructName<$($F),*> {
type Output = ($($F::Output),*);
type Output = ($($F::Output,)*);

fn poll(
self: Pin<&mut Self>, cx: &mut Context<'_>
) -> Poll<Self::Output> {
let mut all_done = true;
let mut this = self.project();
assert!(!*this.done, "Futures must not be polled after completing");

// Poll futures
paste! {
$(
if this.[<$F _state>].is_pending() {
if let Poll::Ready(out) = this.$F.poll(cx) {
*this.[<$F _out>] = MaybeUninit::new(out);
*this.[<$F _state>] = PollState::Done;
}
}
all_done &= this.[<$F _state>].is_done();
)*
}

if all_done {
*this.done = true;
paste! {
Poll::Ready(($( unsafe { this.[<$F _out>].assume_init_read() }),*))
}
poll_all_pending!(LEN, this, cx, $($F,)*);

if *this.len <= 0 {
let out = unsafe {(this.outputs as *const _ as *const ($($F::Output,)*)).read()};
Poll::Ready(out)
} else {
Poll::Pending
}
Expand All @@ -90,39 +138,17 @@ macro_rules! impl_join_tuple {
where $(
$F: IntoFuture,
)* {
type Output = ($($F::Output),*);
type Output = ($($F::Output,)*);
type Future = $StructName<$($F::IntoFuture),*>;

fn join(self) -> Self::Future {
let ($($F,)*): ($($F,)*) = self;
paste! {
$StructName {
done: false,
$(
$F: $F.into_future(),
[<$F _out>]: MaybeUninit::uninit(),
[<$F _state>]: PollState::default(),
)*
}
}
}
}

#[pinned_drop]
impl<$($F,)*> PinnedDrop for $StructName<$($F,)*>
where $(
$F: Future,
)* {
fn drop(self: Pin<&mut Self>) {
let _this = self.project();

paste! {
$(
if _this.[<$F _state>].is_done() {
// SAFETY: if the future is marked as done, we can safelly drop its out
unsafe { _this.[<$F _out>].assume_init_drop() };
}
)*
const LEN: u32 = utils::tuple_len!($($F,)*);
$StructName {
len: LEN,
$($F: $F.into_future(),)*
outputs: ($(MaybeUninit::<$F::Output>::uninit(),)*),
states: PollStates::new(LEN as usize),
}
}
}
Expand Down Expand Up @@ -159,7 +185,7 @@ mod test {
fn join_1() {
futures_lite::future::block_on(async {
let a = future::ready("hello");
assert_eq!((a,).join().await, ("hello"));
assert_eq!((a,).join().await, ("hello",));
});
}

Expand Down
13 changes: 13 additions & 0 deletions src/utils/poll_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ pub(crate) enum PollStates {
Boxed(Box<[PollState]>),
}

impl core::fmt::Debug for PollStates {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Inline(len, states) => f
// .debug_tuple("Inline")
.debug_list()
.entries(&states[..(*len as usize)])
.finish(),
Self::Boxed(states) => f.debug_list().entries(&**states).finish(),
}
}
}

impl PollStates {
pub(crate) fn new(len: usize) -> Self {
assert!(MAX_INLINE_ENTRIES <= u8::MAX as usize);
Expand Down
4 changes: 4 additions & 0 deletions src/utils/tuple.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
/// Compute the number of permutations for a number
/// during compilation.
pub(crate) const fn permutations(mut num: u32) -> u32 {
if num == 0 {
return 0;
}

let mut total = 1;
loop {
total *= num;
Expand Down

0 comments on commit 32a5233

Please sign in to comment.