diff --git a/crates/store/re_chunk_store/src/dataframe.rs b/crates/store/re_chunk_store/src/dataframe.rs index 2075d14f878e..283bf10313f8 100644 --- a/crates/store/re_chunk_store/src/dataframe.rs +++ b/crates/store/re_chunk_store/src/dataframe.rs @@ -104,6 +104,20 @@ impl Ord for TimeColumnDescriptor { } impl TimeColumnDescriptor { + fn metadata(&self) -> arrow2::datatypes::Metadata { + let Self { + timeline, + datatype: _, + } = self; + + std::iter::once(Some(( + "sorbet.index_name".to_owned(), + timeline.name().to_string(), + ))) + .flatten() + .collect() + } + #[inline] // Time column must be nullable since static data doesn't have a time. pub fn to_arrow_field(&self) -> Arrow2Field { @@ -113,6 +127,7 @@ impl TimeColumnDescriptor { datatype.clone(), true, /* nullable */ ) + .with_metadata(self.metadata()) } } diff --git a/rerun_py/rerun_sdk/rerun/any_value.py b/rerun_py/rerun_sdk/rerun/any_value.py index c50318aab895..699441856f8c 100644 --- a/rerun_py/rerun_sdk/rerun/any_value.py +++ b/rerun_py/rerun_sdk/rerun/any_value.py @@ -8,7 +8,7 @@ from rerun._baseclasses import ComponentDescriptor -from . 
import ComponentColumn +from ._baseclasses import ComponentColumn from ._log import AsComponents, ComponentBatchLike from .error_utils import catch_and_log_exceptions diff --git a/rerun_py/rerun_sdk/rerun/dataframe.py b/rerun_py/rerun_sdk/rerun/dataframe.py index 4b34e955a49f..b525bf1e9e81 100644 --- a/rerun_py/rerun_sdk/rerun/dataframe.py +++ b/rerun_py/rerun_sdk/rerun/dataframe.py @@ -1,5 +1,9 @@ from __future__ import annotations +from collections import defaultdict +from typing import Any, Optional + +import pyarrow as pa from rerun_bindings import ( ComponentColumnDescriptor as ComponentColumnDescriptor, ComponentColumnSelector as ComponentColumnSelector, @@ -18,3 +22,99 @@ ComponentLike as ComponentLike, ViewContentsLike as ViewContentsLike, ) + +from ._baseclasses import ComponentColumn, ComponentDescriptor +from ._log import IndicatorComponentBatch +from ._send_columns import TimeColumnLike, send_columns +from .recording_stream import RecordingStream + +SORBET_INDEX_NAME = b"sorbet.index_name" +SORBET_ENTITY_PATH = b"sorbet.path" +SORBET_ARCHETYPE_NAME = b"sorbet.semantic_family" +SORBET_ARCHETYPE_FIELD = b"sorbet.logical_type" +SORBET_COMPONENT_NAME = b"sorbet.semantic_type" +RERUN_KIND = b"rerun.kind" +RERUN_KIND_CONTROL = b"control" +RERUN_KIND_INDEX = b"time" + + +class RawIndexColumn(TimeColumnLike): + def __init__(self, metadata: dict[bytes, bytes], col: pa.Array): + self.metadata = metadata + self.col = col + + def timeline_name(self) -> str: + name = self.metadata.get(SORBET_INDEX_NAME, "unknown") + if isinstance(name, bytes): + name = name.decode("utf-8") + return name + + def as_arrow_array(self) -> pa.Array: + return self.col + + +class RawComponentBatchLike(ComponentColumn): + def __init__(self, metadata: dict[bytes, bytes], col: pa.Array): + self.metadata = metadata + self.col = col + + def component_descriptor(self) -> ComponentDescriptor: + kwargs = {} + if SORBET_ARCHETYPE_NAME in self.metadata: + kwargs["archetype_name"] = 
"rerun.archetypes." + self.metadata[SORBET_ARCHETYPE_NAME].decode("utf-8") + if SORBET_COMPONENT_NAME in self.metadata: + kwargs["component_name"] = "rerun.components." + self.metadata[SORBET_COMPONENT_NAME].decode("utf-8") + if SORBET_ARCHETYPE_FIELD in self.metadata: + kwargs["archetype_field_name"] = self.metadata[SORBET_ARCHETYPE_FIELD].decode("utf-8") + + if "component_name" not in kwargs: + kwargs["component_name"] = "Unknown" + + return ComponentDescriptor(**kwargs) + + def as_arrow_array(self) -> pa.Array: + return self.col + + +def send_record_batch(batch: pa.RecordBatch, rec: Optional[RecordingStream] = None) -> None: + """Coerce a single pyarrow `RecordBatch` to Rerun structure.""" + + indexes = [] + data: defaultdict[str, list[Any]] = defaultdict(list) + archetypes: defaultdict[str, set[Any]] = defaultdict(set) + for col in batch.schema: + metadata = col.metadata or {} + if metadata.get(RERUN_KIND) == RERUN_KIND_CONTROL: + continue + if SORBET_INDEX_NAME in metadata or metadata.get(RERUN_KIND) == RERUN_KIND_INDEX: + if SORBET_INDEX_NAME not in metadata: + metadata[SORBET_INDEX_NAME] = col.name + indexes.append(RawIndexColumn(metadata, batch.column(col.name))) + else: + entity_path = metadata.get(SORBET_ENTITY_PATH, col.name.split(":")[0]) + if isinstance(entity_path, bytes): + entity_path = entity_path.decode("utf-8") + data[entity_path].append(RawComponentBatchLike(metadata, batch.column(col.name))) + if SORBET_ARCHETYPE_NAME in metadata: + archetypes[entity_path].add(metadata[SORBET_ARCHETYPE_NAME].decode("utf-8")) + for entity_path, archetype_set in archetypes.items(): + for archetype in archetype_set: + data[entity_path].append(IndicatorComponentBatch("rerun.archetypes."
+ archetype)) + + for entity_path, columns in data.items(): + send_columns( + entity_path, + indexes, + columns, + # This is fine, send_columns will handle the conversion + recording=rec, # NOLINT + ) + + +def send_dataframe(df: pa.RecordBatchReader | pa.Table, rec: Optional[RecordingStream] = None) -> None: + """Coerce a pyarrow `RecordBatchReader` or `Table` to Rerun structure.""" + if isinstance(df, pa.Table): + df = df.to_reader() + + for batch in df: + send_record_batch(batch, rec) diff --git a/rerun_py/tests/unit/test_dataframe.py b/rerun_py/tests/unit/test_dataframe.py index 69daa71d1da3..067f7fa63b34 100644 --- a/rerun_py/tests/unit/test_dataframe.py +++ b/rerun_py/tests/unit/test_dataframe.py @@ -380,3 +380,19 @@ def test_view_syntax(self) -> None: table = pa.Table.from_batches(batches, batches.schema) assert table.num_columns == 3 assert table.num_rows == 0 + + def test_roundtrip_send(self) -> None: + df = self.recording.view(index="my_index", contents="/**").select().read_all() + + with tempfile.TemporaryDirectory() as tmpdir: + rrd = tmpdir + "/tmp.rrd" + + rr.init("rerun_example_test_recording") + rr.dataframe.send_dataframe(df) + rr.save(rrd) + + round_trip_recording = rr.dataframe.load_recording(rrd) + + df_round_trip = round_trip_recording.view(index="my_index", contents="/**").select().read_all() + + assert df == df_round_trip