Skip to content

Commit

Permalink
feat: Support Arrow PyCapsule Interface for export (#786)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 14, 2024
1 parent 6075ec7 commit 350fe7d
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
handler: python
options:
members:
- __arrow_c_stream__
- __getitem__
- clone
- collect_schema
Expand Down
2 changes: 2 additions & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
handler: python
options:
members:
- __arrow_c_stream__
- __getitem__
- abs
- alias
- all
Expand Down
25 changes: 25 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals.dependencies import is_numpy_array
from narwhals.schema import Schema
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from io import BytesIO
Expand Down Expand Up @@ -249,6 +250,30 @@ def __repr__(self) -> str: # pragma: no cover
+ "┘"
)

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""
Export a DataFrame via the Arrow PyCapsule Interface.
- if the underlying dataframe implements the interface, it'll return that
- else, it'll call `to_arrow` and then defer to PyArrow's implementation
See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html)
for more.
"""
native_frame = self._compliant_frame._native_frame
if hasattr(native_frame, "__arrow_c_stream__"):
return native_frame.__arrow_c_stream__(requested_schema=requested_schema)
try:
import pyarrow as pa # ignore-banned-import
except ModuleNotFoundError as exc: # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from exc
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from None
pa_table = self.to_arrow()
return pa_table.__arrow_c_stream__(requested_schema=requested_schema)

def lazy(self) -> LazyFrame[Any]:
"""
Lazify the DataFrame (if possible).
Expand Down
28 changes: 28 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Sequence
from typing import overload

from narwhals.utils import parse_version

if TYPE_CHECKING:
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -57,6 +59,32 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
# Return the native namespace as reported by the wrapped compliant series.
# This method adds no logic of its own: it forwards the call to
# `self._compliant_series.__native_namespace__()` and returns the result.
def __native_namespace__(self) -> Any:
return self._compliant_series.__native_namespace__()

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
    """
    Export a Series via the Arrow PyCapsule Interface.

    Narwhals itself implements nothing here; it only picks who does the export:

    - a native series that already exposes `__arrow_c_stream__` is asked
      to export itself;
    - any other series is converted with `to_arrow`, wrapped in a PyArrow
      chunked array, and exported by PyArrow.

    See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html)
    for more.

    `requested_schema` is forwarded unchanged, as a keyword argument, to
    whichever implementation is used. `ModuleNotFoundError` is raised when
    the fallback is needed but PyArrow is missing or older than 16.0.
    """
    backend_series = self._compliant_series._native_series
    if hasattr(backend_series, "__arrow_c_stream__"):
        return backend_series.__arrow_c_stream__(requested_schema=requested_schema)

    # Shared by both failure branches of the PyArrow fallback.
    requirement = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(backend_series)}"
    try:
        import pyarrow as pa  # ignore-banned-import
    except ModuleNotFoundError as exc:  # pragma: no cover
        raise ModuleNotFoundError(requirement) from exc
    installed = parse_version(pa.__version__)
    if installed < (16, 0):  # pragma: no cover
        raise ModuleNotFoundError(requirement)
    # `to_arrow` result is wrapped as the single chunk of a chunked array,
    # which is the object that provides the stream export.
    return pa.chunked_array([self.to_arrow()]).__arrow_c_stream__(
        requested_schema=requested_schema
    )

@property
def shape(self) -> tuple[int]:
"""
Expand Down
42 changes: 42 additions & 0 deletions tests/frame/arrow_c_stream_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test() -> None:
    """A Narwhals DataFrame can be consumed directly by `pa.table`."""
    native = pl.Series([1, 2, 3]).to_frame("a")
    df = nw.from_native(native, eager_only=True)
    result = pa.table(df)
    expected = pa.table({"a": [1, 2, 3]})
    # Element-wise comparison of the single column, reduced to one bool.
    elementwise = pc.equal(result["a"], expected["a"])
    assert pc.all(elementwise).as_py()


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
    """`pa.table` must go through `DataFrame.__arrow_c_stream__`.

    The dunder is replaced with a callable that always divides by zero; if
    `pa.table` really invokes it, that error surfaces here.
    """

    def _poisoned(*_: object) -> float:
        return 1 / 0

    monkeypatch.setattr(
        "narwhals.dataframe.DataFrame.__arrow_c_stream__", _poisoned
    )
    native = pl.Series([1, 2, 3]).to_frame("a")
    df = nw.from_native(native, eager_only=True)
    with pytest.raises(ZeroDivisionError, match="division by zero"):
        pa.table(df)


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
@pytest.mark.skipif(
    parse_version(pa.__version__) < (14, 0),
    reason="too old for pycapsule in PyArrow",
)
def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """The PyArrow fallback is used when the native frame lacks the dunder.

    The fallback in `DataFrame.__arrow_c_stream__` raises for PyArrow < 14.0,
    so the test is skipped (rather than failed) on such versions.
    """
    # Remove the native implementation so Narwhals has to go via `to_arrow`.
    monkeypatch.delattr("polars.DataFrame.__arrow_c_stream__")
    df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
    result = pa.table(df)
    expected = pa.table({"a": [1, 2, 3]})
    assert pc.all(pc.equal(result["a"], expected["a"])).as_py()
41 changes: 41 additions & 0 deletions tests/series_only/arrow_c_stream_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test() -> None:
    """A Narwhals Series can be consumed directly by `pa.chunked_array`."""
    native = pl.Series([1, 2, 3])
    s = nw.from_native(native, series_only=True)
    result = pa.chunked_array(s)
    expected = pa.chunked_array([[1, 2, 3]])
    # Element-wise comparison, reduced to one bool.
    elementwise = pc.equal(result, expected)
    assert pc.all(elementwise).as_py()


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
    """`pa.chunked_array` must go through `Series.__arrow_c_stream__`.

    The dunder is replaced with a callable that always divides by zero; if
    `pa.chunked_array` really invokes it, that error surfaces here.
    """

    def _poisoned(*_: object) -> float:
        return 1 / 0

    monkeypatch.setattr("narwhals.series.Series.__arrow_c_stream__", _poisoned)
    native = pl.Series([1, 2, 3])
    s = nw.from_native(native, series_only=True)
    with pytest.raises(ZeroDivisionError, match="division by zero"):
        pa.chunked_array(s)


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
@pytest.mark.skipif(
    parse_version(pa.__version__) < (16, 0),
    reason="too old for pycapsule in PyArrow",
)
def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """The PyArrow fallback is used when the native series lacks the dunder.

    The fallback in `Series.__arrow_c_stream__` raises for PyArrow < 16.0,
    so the test is skipped (rather than failed) on such versions.
    """
    # Remove the native implementation so Narwhals has to go via `to_arrow`.
    monkeypatch.delattr("polars.Series.__arrow_c_stream__")
    s = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)["a"]
    # Direct call with no argument: exercises the default `requested_schema`;
    # the returned object is intentionally discarded.
    s.__arrow_c_stream__()
    result = pa.chunked_array(s)
    expected = pa.chunked_array([[1, 2, 3]])
    assert pc.all(pc.equal(result, expected)).as_py()
6 changes: 3 additions & 3 deletions utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@
documented = [
remove_prefix(i, " - ")
for i in content.splitlines()
if i.startswith(" - ")
if i.startswith(" - ") and not i.startswith(" - _")
]
if missing := set(top_level_functions).difference(documented):
print("DataFrame: not documented") # noqa: T201
print(missing) # noqa: T201
ret = 1
if extra := set(documented).difference(top_level_functions).difference({"__getitem__"}):
if extra := set(documented).difference(top_level_functions):
print("DataFrame: outdated") # noqa: T201
print(extra) # noqa: T201
ret = 1
Expand Down Expand Up @@ -87,7 +87,7 @@
documented = [
remove_prefix(i, " - ")
for i in content.splitlines()
if i.startswith(" - ")
if i.startswith(" - ") and not i.startswith(" - _")
]
if (
missing := set(top_level_functions)
Expand Down

0 comments on commit 350fe7d

Please sign in to comment.