Introduce new API to send a dataframe to Rerun (#8461)
### What

This is not perfect, since Sorbet hasn't been formalized.

This decodes the placeholder sorbet data we currently use in our query
results, as well as some of the Rerun-chunk metadata.

Eventually we should move this onto the Rust side of things, but as this
is largely just metadata processing, doing it in Python is not terrible.
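
For context, here's a minimal sketch of how the new API can be exercised end to end (the recording path and index name are hypothetical, not part of this change):

```python
import rerun as rr

rr.init("rerun_example_send_dataframe", spawn=True)

# Query an existing recording back out as a pyarrow table...
recording = rr.dataframe.load_recording("/path/to/recording.rrd")
df = recording.view(index="log_time", contents="/**").select().read_all()

# ...and send it to the active recording stream: index columns become
# timelines, and component columns are routed to entity paths / archetypes
# based on the sorbet metadata attached to each column.
rr.dataframe.send_dataframe(df)
```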
jleibs authored Dec 16, 2024 · 1 parent 947d708 · commit 8466c45
Showing 4 changed files with 132 additions and 1 deletion.
15 changes: 15 additions & 0 deletions crates/store/re_chunk_store/src/dataframe.rs
@@ -104,6 +104,20 @@ impl Ord for TimeColumnDescriptor
}

impl TimeColumnDescriptor {
    fn metadata(&self) -> arrow2::datatypes::Metadata {
        let Self {
            timeline,
            datatype: _,
        } = self;

        std::iter::once(Some((
            "sorbet.index_name".to_owned(),
            timeline.name().to_string(),
        )))
        .flatten()
        .collect()
    }

    #[inline]
    // Time column must be nullable since static data doesn't have a time.
    pub fn to_arrow_field(&self) -> Arrow2Field {
@@ -113,6 +127,7 @@ impl TimeColumnDescriptor {
            datatype.clone(),
            true, /* nullable */
        )
        .with_metadata(self.metadata())
    }
}

2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/any_value.py
@@ -8,7 +8,7 @@

from rerun._baseclasses import ComponentDescriptor

-from . import ComponentColumn
+from ._baseclasses import ComponentColumn
from ._log import AsComponents, ComponentBatchLike
from .error_utils import catch_and_log_exceptions

100 changes: 100 additions & 0 deletions rerun_py/rerun_sdk/rerun/dataframe.py
@@ -1,5 +1,9 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Optional

import pyarrow as pa
from rerun_bindings import (
    ComponentColumnDescriptor as ComponentColumnDescriptor,
    ComponentColumnSelector as ComponentColumnSelector,
@@ -18,3 +22,99 @@
    ComponentLike as ComponentLike,
    ViewContentsLike as ViewContentsLike,
)

from ._baseclasses import ComponentColumn, ComponentDescriptor
from ._log import IndicatorComponentBatch
from ._send_columns import TimeColumnLike, send_columns
from .recording_stream import RecordingStream

SORBET_INDEX_NAME = b"sorbet.index_name"
SORBET_ENTITY_PATH = b"sorbet.path"
SORBET_ARCHETYPE_NAME = b"sorbet.semantic_family"
SORBET_ARCHETYPE_FIELD = b"sorbet.logical_type"
SORBET_COMPONENT_NAME = b"sorbet.semantic_type"
RERUN_KIND = b"rerun.kind"
RERUN_KIND_CONTROL = b"control"
RERUN_KIND_INDEX = b"time"


class RawIndexColumn(TimeColumnLike):
    def __init__(self, metadata: dict[bytes, bytes], col: pa.Array):
        self.metadata = metadata
        self.col = col

    def timeline_name(self) -> str:
        name = self.metadata.get(SORBET_INDEX_NAME, "unknown")
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        return name

    def as_arrow_array(self) -> pa.Array:
        return self.col


class RawComponentBatchLike(ComponentColumn):
    def __init__(self, metadata: dict[bytes, bytes], col: pa.Array):
        self.metadata = metadata
        self.col = col

    def component_descriptor(self) -> ComponentDescriptor:
        kwargs = {}
        if SORBET_ARCHETYPE_NAME in self.metadata:
            kwargs["archetype_name"] = "rerun.archetypes." + self.metadata[SORBET_ARCHETYPE_NAME].decode("utf-8")
        if SORBET_COMPONENT_NAME in self.metadata:
            kwargs["component_name"] = "rerun.components." + self.metadata[SORBET_COMPONENT_NAME].decode("utf-8")
        if SORBET_ARCHETYPE_FIELD in self.metadata:
            kwargs["archetype_field_name"] = self.metadata[SORBET_ARCHETYPE_FIELD].decode("utf-8")

        if "component_name" not in kwargs:
            kwargs["component_name"] = "Unknown"

        return ComponentDescriptor(**kwargs)

    def as_arrow_array(self) -> pa.Array:
        return self.col


def send_record_batch(batch: pa.RecordBatch, rec: Optional[RecordingStream] = None) -> None:
    """Coerce a single pyarrow `RecordBatch` to Rerun structure."""

    indexes = []
    data: defaultdict[str, list[Any]] = defaultdict(list)
    archetypes: defaultdict[str, set[Any]] = defaultdict(set)
    for col in batch.schema:
        metadata = col.metadata or {}
        if metadata.get(RERUN_KIND) == RERUN_KIND_CONTROL:
            continue
        if SORBET_INDEX_NAME in metadata or metadata.get(RERUN_KIND) == RERUN_KIND_INDEX:
            if SORBET_INDEX_NAME not in metadata:
                metadata[SORBET_INDEX_NAME] = col.name
            indexes.append(RawIndexColumn(metadata, batch.column(col.name)))
        else:
            entity_path = metadata.get(SORBET_ENTITY_PATH, col.name.split(":")[0])
            if isinstance(entity_path, bytes):
                entity_path = entity_path.decode("utf-8")
            data[entity_path].append(RawComponentBatchLike(metadata, batch.column(col.name)))
            if SORBET_ARCHETYPE_NAME in metadata:
                archetypes[entity_path].add(metadata[SORBET_ARCHETYPE_NAME].decode("utf-8"))
    for entity_path, archetype_set in archetypes.items():
        for archetype in archetype_set:
            data[entity_path].append(IndicatorComponentBatch("rerun.archetypes." + archetype))

    for entity_path, columns in data.items():
        send_columns(
            entity_path,
            indexes,
            columns,
            # This is fine, send_columns will handle the conversion
            recording=rec,  # NOLINT
        )


def send_dataframe(df: pa.RecordBatchReader | pa.Table, rec: Optional[RecordingStream] = None) -> None:
    """Coerce a pyarrow `RecordBatchReader` or `Table` to Rerun structure."""
    if isinstance(df, pa.Table):
        df = df.to_reader()

    for batch in df:
        send_record_batch(batch, rec)
16 changes: 16 additions & 0 deletions rerun_py/tests/unit/test_dataframe.py
@@ -380,3 +380,19 @@ def test_view_syntax(self) -> None:
        table = pa.Table.from_batches(batches, batches.schema)
        assert table.num_columns == 3
        assert table.num_rows == 0

    def test_roundtrip_send(self) -> None:
        df = self.recording.view(index="my_index", contents="/**").select().read_all()

        with tempfile.TemporaryDirectory() as tmpdir:
            rrd = tmpdir + "/tmp.rrd"

            rr.init("rerun_example_test_recording")
            rr.dataframe.send_dataframe(df)
            rr.save(rrd)

            round_trip_recording = rr.dataframe.load_recording(rrd)

            df_round_trip = round_trip_recording.view(index="my_index", contents="/**").select().read_all()

            assert df == df_round_trip
