From 9c3b9a1424422db2da05082cf8a54ac29ded62dd Mon Sep 17 00:00:00 2001 From: Antoine Beyeler Date: Fri, 11 Oct 2024 11:24:00 +0200 Subject: [PATCH] Add support for NumPy arrays to the `Utf8` datatype arrow serializer --- rerun_py/rerun_sdk/rerun/datatypes/utf8.py | 12 +++------- .../rerun_sdk/rerun/datatypes/utf8_ext.py | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) create mode 100644 rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py diff --git a/rerun_py/rerun_sdk/rerun/datatypes/utf8.py b/rerun_py/rerun_sdk/rerun/datatypes/utf8.py index 6d51ee43b1ff2..91e73e5d4a07d 100644 --- a/rerun_py/rerun_sdk/rerun/datatypes/utf8.py +++ b/rerun_py/rerun_sdk/rerun/datatypes/utf8.py @@ -14,12 +14,13 @@ BaseBatch, BaseExtensionType, ) +from .utf8_ext import Utf8Ext __all__ = ["Utf8", "Utf8ArrayLike", "Utf8Batch", "Utf8Like", "Utf8Type"] @define(init=False) -class Utf8: +class Utf8(Utf8Ext): """**Datatype**: A string of text, encoded as UTF-8.""" def __init__(self: Any, value: Utf8Like): @@ -57,11 +58,4 @@ class Utf8Batch(BaseBatch[Utf8ArrayLike]): @staticmethod def _native_to_pa_array(data: Utf8ArrayLike, data_type: pa.DataType) -> pa.Array: - if isinstance(data, str): - array = [data] - elif isinstance(data, Sequence): - array = [str(datum) for datum in data] - else: - array = [str(data)] - - return pa.array(array, type=data_type) + return Utf8Ext.native_to_pa_array_override(data, data_type) diff --git a/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py b/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py new file mode 100644 index 0000000000000..7972a0bab097a --- /dev/null +++ b/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +import numpy as np +import pyarrow as pa + +if TYPE_CHECKING: + from . import Utf8ArrayLike + + +class Utf8Ext: + @staticmethod + def native_to_pa_array_override(data: Utf8ArrayLike, data_type: pa.DataType) -> pa.Array: + if isinstance(data, str): + array = [data] + elif isinstance(data, Sequence): + array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data + else: + array = [str(data)] + + return pa.array(array, type=data_type)