Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Arrow PyCapsule Interface for export #786

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals.dependencies import get_polars
from narwhals.schema import Schema
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from io import BytesIO
Expand Down Expand Up @@ -259,6 +260,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
native_frame = self._compliant_frame._native_frame
if hasattr(native_frame, "__arrow_c_stream__"):
return native_frame.__arrow_c_stream__(requested_schema=requested_schema)
try:
import pyarrow as pa
except ModuleNotFoundError as exc: # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from exc
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from None
pa_table = self.to_arrow()
return pa_table.__arrow_c_stream__(requested_schema=requested_schema)

Expand Down
11 changes: 9 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Sequence
from typing import overload

from narwhals.dependencies import get_pyarrow
from narwhals.utils import parse_version

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -71,7 +71,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
native_series = self._compliant_series._native_series
if hasattr(native_series, "__arrow_c_stream__"):
return native_series.__arrow_c_stream__(requested_schema=requested_schema)
pa = get_pyarrow()
try:
import pyarrow as pa
except ModuleNotFoundError as exc: # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}"
raise ModuleNotFoundError(msg) from exc
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}"
raise ModuleNotFoundError(msg)
ca = pa.chunked_array([self.to_arrow()])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might require pyarrow 15

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in pandas the requirement is PyArrow 14+ (I also just ran the tests with pyarrow 13 and 14 - the former fails, the latter passes)

Copy link
Member Author

@MarcoGorelli MarcoGorelli Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah sorry, that's for DataFrame. looks like it's even PyArrow 16+ for chunkedarray?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was added to pa.chunked_array in a later release, yes. I think it was here: apache/arrow#40818

return ca.__arrow_c_stream__(requested_schema=requested_schema)

Expand Down
Loading