Dataframe v2: make num_rows aware of filters and such #7621

Merged · 7 commits · Oct 7, 2024

Changes from all commits
97 changes: 88 additions & 9 deletions crates/store/re_dataframe2/src/query.rs
@@ -6,6 +6,7 @@ use std::sync::{
use ahash::HashSet;
use arrow2::{
array::Array as ArrowArray, chunk::Chunk as ArrowChunk, datatypes::Schema as ArrowSchema,
Either,
};
use itertools::Itertools;

@@ -376,22 +377,44 @@ impl QueryHandle<'_> {
pub fn num_rows(&self) -> u64 {
re_tracing::profile_function!();

let all_unique_timestamps: HashSet<TimeInt> = self
.init()
.view_chunks
.iter()
let state = self.init();

let mut view_chunks = state.view_chunks.iter();
let view_chunks = if let Some(view_pov_chunks_idx) = state.view_pov_chunks_idx {
Either::Left(view_chunks.nth(view_pov_chunks_idx).into_iter())
} else {
Either::Right(view_chunks)
};

let mut all_unique_timestamps: HashSet<TimeInt> = view_chunks
.flat_map(|chunks| {
chunks.iter().filter_map(|(_cursor, chunk)| {
chunk
.timelines()
.get(&self.query.filtered_index)
.map(|time_column| time_column.times())
if chunk.is_static() {
Some(Either::Left(std::iter::once(TimeInt::STATIC)))
} else {
chunk
.timelines()
.get(&self.query.filtered_index)
.map(|time_column| Either::Right(time_column.times()))
}
})
})
.flatten()
.collect();

all_unique_timestamps.len() as _
if let Some(filtered_index_values) = self.query.filtered_index_values.as_ref() {
all_unique_timestamps.retain(|time| filtered_index_values.contains(time));
}

let num_rows = all_unique_timestamps.len() as _;

if cfg!(debug_assertions) {
let expected_num_rows =
self.engine.query(self.query.clone()).into_iter().count() as u64;
assert_eq!(expected_num_rows, num_rows);
}

num_rows
}

/// Returns the next row's worth of data.
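Taken together, `num_rows()` now mirrors what the row iterator will actually yield: when `view_pov_chunks_idx` is set only that chunk set contributes, static chunks count as a single `TimeInt::STATIC` row, and any `filtered_index_values` are applied before counting; a debug-only assertion cross-checks the result against the row count of a freshly executed query. A rough, self-contained sketch of the counting logic, using stand-in types and `itertools::Either` rather than the real `Chunk`/`TimeInt` API, might look like this:

```rust
use std::collections::HashSet;

use itertools::Either;

/// Illustrative stand-in for a chunk: either static, or carrying index values
/// on the filtered index. Not the real chunk type.
struct ChunkStub {
    is_static: bool,
    index_values: Vec<i64>,
}

/// Rough sketch: count the unique index values across the relevant chunks,
/// mapping static chunks to a single sentinel value, then apply the filter.
fn count_rows(
    chunks: &[ChunkStub],
    pov_chunk_idx: Option<usize>,
    filtered_index_values: Option<&HashSet<i64>>,
) -> u64 {
    // If a point-of-view chunk index is set, only that chunk contributes rows.
    let relevant = match pov_chunk_idx {
        Some(idx) => Either::Left(chunks.get(idx).into_iter()),
        None => Either::Right(chunks.iter()),
    };

    let mut unique: HashSet<i64> = relevant
        .flat_map(|chunk| {
            if chunk.is_static {
                // Stand-in for `TimeInt::STATIC`: static data counts as one row.
                Either::Left(std::iter::once(i64::MIN))
            } else {
                Either::Right(chunk.index_values.iter().copied())
            }
        })
        .collect();

    // Honor an explicit index-value filter, if the query has one.
    if let Some(allowed) = filtered_index_values {
        unique.retain(|value| allowed.contains(value));
    }

    unique.len() as u64
}
```

The same invariant — `num_rows()` equals the number of rows the handle actually produces — is what each test below now asserts right after building the query handle.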
@@ -886,6 +909,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -928,6 +955,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -970,6 +1001,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1018,6 +1053,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1068,6 +1107,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1091,6 +1134,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1114,6 +1161,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1147,6 +1198,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1198,6 +1253,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1231,6 +1290,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1277,6 +1340,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1306,6 +1373,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1354,6 +1425,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -1434,6 +1509,10 @@ mod tests {
eprintln!("{query:#?}:");

let query_handle = query_engine.query(query.clone());
assert_eq!(
query_engine.query(query.clone()).into_iter().count() as u64,
query_handle.num_rows()
);
let dataframe = concatenate_record_batches(
query_handle.schema().clone(),
&query_handle.into_batch_iter().collect_vec(),
@@ -79,7 +79,8 @@ impl Query {
) -> Result<Option<components::FilterByEvent>, SpaceViewSystemExecutionError> {
Ok(self
.query_property
.component_or_empty::<components::FilterByEvent>()?)
.component_or_empty::<components::FilterByEvent>()?
.filter(|filter_by_event| filter_by_event.active()))
}

pub(super) fn save_filter_by_event(
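On the space-view side, `filter_by_event` now returns `None` when a filter component is stored but toggled off, via `Option::filter`, so callers only ever see an active filter. A tiny illustration of that pattern, with a stand-in type rather than the real `components::FilterByEvent`:

```rust
/// Stand-in for an event-filter component with an on/off toggle.
struct EventFilterStub {
    active: bool,
}

/// Treat a stored-but-disabled filter the same as no filter at all.
fn effective_filter(stored: Option<EventFilterStub>) -> Option<EventFilterStub> {
    stored.filter(|filter| filter.active)
}
```

For example, `effective_filter(Some(EventFilterStub { active: false }))` yields `None`, just as if no filter had been stored.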
10 changes: 0 additions & 10 deletions rerun_py/rerun_bindings/rerun_bindings.pyi
@@ -4,15 +4,6 @@ import pyarrow as pa

from .types import AnyColumn, ComponentLike, ViewContentsLike

class ControlColumnDescriptor:
"""A control-level column such as `RowId`."""

class ControlColumnSelector:
"""A selector for a control column."""

@staticmethod
def row_id() -> ControlColumnSelector: ...

class IndexColumnDescriptor:
"""A column containing the index values for when the component data was updated."""

@@ -35,7 +26,6 @@ class ComponentColumnSelector:
class Schema:
"""The schema representing all columns in a [`Recording`][]."""

def control_columns(self) -> list[ControlColumnDescriptor]: ...
def index_columns(self) -> list[IndexColumnDescriptor]: ...
def component_columns(self) -> list[ComponentColumnDescriptor]: ...
def column_for(self, entity_path: str, component: ComponentLike) -> Optional[ComponentColumnDescriptor]: ...
5 changes: 0 additions & 5 deletions rerun_py/rerun_bindings/types.py
@@ -8,20 +8,15 @@
from .rerun_bindings import (
ComponentColumnDescriptor as ComponentColumnDescriptor,
ComponentColumnSelector as ComponentColumnSelector,
ControlColumnDescriptor as ControlColumnDescriptor,
ControlColumnSelector as ControlColumnSelector,
TimeColumnDescriptor as TimeColumnDescriptor,
TimeColumnSelector as TimeColumnSelector,
)


ComponentLike: TypeAlias = Union[str, type["ComponentMixin"]]

AnyColumn: TypeAlias = Union[
"ControlColumnDescriptor",
"TimeColumnDescriptor",
"ComponentColumnDescriptor",
"ControlColumnSelector",
"TimeColumnSelector",
"ComponentColumnSelector",
]
2 changes: 0 additions & 2 deletions rerun_py/rerun_sdk/rerun/dataframe.py
@@ -3,8 +3,6 @@
from rerun_bindings import (
ComponentColumnDescriptor as ComponentColumnDescriptor,
ComponentColumnSelector as ComponentColumnSelector,
ControlColumnDescriptor as ControlColumnDescriptor,
ControlColumnSelector as ControlColumnSelector,
Recording as Recording,
RRDArchive as RRDArchive,
Schema as Schema,