Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the need for collect_inner() on QuerySortedManyIter #16650

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 182 additions & 39 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use core::{
};

use super::{QueryData, QueryFilter, ReadOnlyQueryData};
use alloc::vec::IntoIter;

/// An [`Iterator`] over query results of a [`Query`](crate::system::Query).
///
Expand Down Expand Up @@ -1453,8 +1452,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
/// # let entity_list: Vec<Entity> = Vec::new();
/// // We need to collect the internal iterator before iterating mutably
/// let mut parent_query_iter = query.iter_many_mut(entity_list)
/// .sort::<Entity>()
/// .collect_inner();
/// .sort::<Entity>();
///
/// let mut scratch_value = 0;
/// while let Some(mut part_value) = parent_query_iter.fetch_next_back()
Expand Down Expand Up @@ -1496,11 +1494,43 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};

// Wrap the `L::Item` in a `MaybeUninit` to avoid aliasing issues.
//
// `L::Item` and `D::Item` may have conflicting access,
// so we must ensure that we never have both alive for the same entity.
// The same entity may appear multiple times in `self.entity_iter`,
// so we must ensure that no `L::Item` is alive when we create `D::Item`s in `fetch_next`.
// The `entity_list` we return is a `Map<vec::IntoIter>` that owns the contents of `keyed_query`,
// so it would keep the `L::Item`s alive if we store them in `keyed_query` directly.
// The safest solution is to `collect()` the `Map` into a fresh `Vec<Entity>`,
// but we want to avoid the overhead of doing a new allocation and copying the data.
// We could re-use a single allocation by storing `Option<L::Item>` and setting them to `None`,
// but that would still have overhead of writing the `None` values.
// Instead, we store `MaybeUninit<L::Item>`, which acts similar to `Option`.
// The compiler can never assume the item is initialized, so we don't need to do anything to uninitialize it.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (key, NeutralOrd(entity)))
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort();
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity.0);
keyed_query.sort_by(|(key_1, _), (key_2, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let key_1 = unsafe { key_1.assume_init_ref() };
// SAFETY: The item was originally initialized, and has never been dropped
let key_2 = unsafe { key_2.assume_init_ref() };
key_1.cmp(key_2)
});

// Run any `L::Item` drop glue before returning, since `MaybeUninit` cannot run it for us.
// Following this loop, we consider the items to all be uninitialized.
//
// Note that most `QueryData` items, including `&C` and `Mut<C>`, have no drop glue, so this is usually a no-op.
// Also note that doing this here means the items will leak if the `sort` method panics.
// That could be resolved by wrapping them in a new type with a `Drop` impl.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
// Each lens query item is dropped before the respective actual query item is accessed.
Expand Down Expand Up @@ -1585,11 +1615,23 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (key, NeutralOrd(entity)))
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_unstable();
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity.0);
keyed_query.sort_unstable_by(|(key_1, _), (key_2, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let key_1 = unsafe { key_1.assume_init_ref() };
// SAFETY: The item was originally initialized, and has never been dropped
let key_2 = unsafe { key_2.assume_init_ref() };
key_1.cmp(key_2)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
// Each lens query item is dropped before the respective actual query item is accessed.
Expand Down Expand Up @@ -1679,8 +1721,22 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
let mut keyed_query: Vec<_> = query_lens.collect();
keyed_query.sort_by(|(key_1, _), (key_2, _)| compare(key_1, key_2));
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_by(|(key_1, _), (key_2, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let key_1 = unsafe { key_1.assume_init_ref() };
// SAFETY: The item was originally initialized, and has never been dropped
let key_2 = unsafe { key_2.assume_init_ref() };
compare(key_1, key_2)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
Expand Down Expand Up @@ -1737,8 +1793,22 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
let mut keyed_query: Vec<_> = query_lens.collect();
keyed_query.sort_unstable_by(|(key_1, _), (key_2, _)| compare(key_1, key_2));
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_unstable_by(|(key_1, _), (key_2, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let key_1 = unsafe { key_1.assume_init_ref() };
// SAFETY: The item was originally initialized, and has never been dropped
let key_2 = unsafe { key_2.assume_init_ref() };
compare(key_1, key_2)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
Expand Down Expand Up @@ -1861,8 +1931,20 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
let mut keyed_query: Vec<_> = query_lens.collect();
keyed_query.sort_by_key(|(lens, _)| f(lens));
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_by_key(|(lens, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let lens = unsafe { lens.assume_init_ref() };
f(lens)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
Expand Down Expand Up @@ -1922,8 +2004,20 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
let mut keyed_query: Vec<_> = query_lens.collect();
keyed_query.sort_unstable_by_key(|(lens, _)| f(lens));
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_unstable_by_key(|(lens, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let lens = unsafe { lens.assume_init_ref() };
f(lens)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
Expand Down Expand Up @@ -1983,8 +2077,20 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item: Borrow<Entity>>>
world.change_tick(),
)
};
let mut keyed_query: Vec<_> = query_lens.collect();
keyed_query.sort_by_cached_key(|(lens, _)| f(lens));
// See the comments in `sort()` for an explanation of `MaybeUninit`.
let mut keyed_query: Vec<_> = query_lens
.map(|(key, entity)| (MaybeUninit::new(key), entity))
.collect();
keyed_query.sort_by_cached_key(|(lens, _)| {
// SAFETY: The item was originally initialized, and has never been dropped
let lens = unsafe { lens.assume_init_ref() };
f(lens)
});
// See the comments in `sort()` for an explanation of `MaybeUninit`.
for (item, _) in &mut keyed_query {
// SAFETY: The item was originally initialized, and has never been dropped
unsafe { item.assume_init_drop() };
}
let entity_iter = keyed_query.into_iter().map(|(.., entity)| entity);
// SAFETY:
// `self.world` has permission to access the required components.
Expand Down Expand Up @@ -2178,25 +2284,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter, I: Iterator<Item = Entity>>
unsafe { D::fetch(&mut self.fetch, entity, location.table_row) }
}

/// Collects the internal [`I`](QuerySortedManyIter) once.
/// [`fetch_next`](QuerySortedManyIter) and [`fetch_next_back`](QuerySortedManyIter) require this to be called first.
#[inline(always)]
pub fn collect_inner(self) -> QuerySortedManyIter<'w, 's, D, F, IntoIter<Entity>> {
QuerySortedManyIter {
entity_iter: self.entity_iter.collect::<Vec<_>>().into_iter(),
entities: self.entities,
tables: self.tables,
archetypes: self.archetypes,
fetch: self.fetch,
query_state: self.query_state,
}
}
}

impl<'w, 's, D: QueryData, F: QueryFilter> QuerySortedManyIter<'w, 's, D, F, IntoIter<Entity>> {
/// Get next result from the query
/// [`collect_inner`](QuerySortedManyIter) needs to be called before this method becomes available.
/// This is done to prevent mutable aliasing.
#[inline(always)]
pub fn fetch_next(&mut self) -> Option<D::Item<'_>> {
let entity = self.entity_iter.next()?;
Expand All @@ -2210,10 +2298,12 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QuerySortedManyIter<'w, 's, D, F, Int
// `entity` is passed from `entity_iter` the first time.
unsafe { D::shrink(self.fetch_next_aliased_unchecked(entity)).into() }
}
}

impl<'w, 's, D: QueryData, F: QueryFilter, I: DoubleEndedIterator<Item = Entity>>
QuerySortedManyIter<'w, 's, D, F, I>
{
/// Get next result from the query
/// [`collect_inner`](QuerySortedManyIter) needs to be called before this method becomes available.
/// This is done to prevent mutable aliasing.
#[inline(always)]
pub fn fetch_next_back(&mut self) -> Option<D::Item<'_>> {
let entity = self.entity_iter.next_back()?;
Expand Down Expand Up @@ -3091,4 +3181,57 @@ mod tests {

iter_2.sort::<Entity>();
}

// This test should be run with miri to check for UB caused by aliasing.
// The lens items created during the sort must not be live at the same time as the mutable references returned from the iterator.
#[test]
fn query_iter_many_sorts_duplicate_entities_no_ub() {
#[derive(Component, Ord, PartialOrd, Eq, PartialEq)]
struct C(usize);

let mut world = World::new();
let id = world.spawn(C(10)).id();
let mut query_state = world.query::<&mut C>();

{
let mut query = query_state.iter_many_mut(&mut world, [id, id]).sort::<&C>();
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_unstable::<&C>();
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_by::<&C>(Ord::cmp);
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_unstable_by::<&C>(Ord::cmp);
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_by_key::<&C, _>(|d| d.0);
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_unstable_by_key::<&C, _>(|d| d.0);
while query.fetch_next().is_some() {}
}
{
let mut query = query_state
.iter_many_mut(&mut world, [id, id])
.sort_by_cached_key::<&C, _>(|d| d.0);
while query.fetch_next().is_some() {}
}
}
}
Loading