Skip to content

Commit

Permalink
feat: add **kwargs to read_csv and scan_csv (#1560)
Browse files Browse the repository at this point in the history
* feat: add kwargs to read_csv and scan_csv

* reword kwargs, add an example

* pass kwargs for unlisted namespace, add test for kwargs engine

* xfail kwargs test for old pandas versions
  • Loading branch information
raisadz authored Dec 11, 2024
1 parent a2088f4 commit f483eaa
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 40 deletions.
39 changes: 20 additions & 19 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,12 +942,16 @@ def read_csv(
source: str,
*,
native_namespace: ModuleType,
**kwargs: Any,
) -> DataFrame[Any]:
"""Read a CSV file into a DataFrame.
Arguments:
source: Path to a file.
native_namespace: The native library to use for DataFrame creation.
kwargs: Extra keyword arguments which are passed to the native CSV reader.
For example, you could use
`nw.read_csv('file.csv', native_namespace=pd, engine='pyarrow')`.
Returns:
DataFrame.
Expand Down Expand Up @@ -991,13 +995,11 @@ def read_csv(
a: [[1,2,3]]
b: [[4,5,6]]
"""
return _read_csv_impl(source, native_namespace=native_namespace)
return _read_csv_impl(source, native_namespace=native_namespace, **kwargs)


def _read_csv_impl(
source: str,
*,
native_namespace: ModuleType,
source: str, *, native_namespace: ModuleType, **kwargs: Any
) -> DataFrame[Any]:
implementation = Implementation.from_native_namespace(native_namespace)
if implementation in (
Expand All @@ -1006,26 +1008,24 @@ def _read_csv_impl(
Implementation.MODIN,
Implementation.CUDF,
):
native_frame = native_namespace.read_csv(source)
native_frame = native_namespace.read_csv(source, **kwargs)
elif implementation is Implementation.PYARROW:
from pyarrow import csv # ignore-banned-import

native_frame = csv.read_csv(source)
native_frame = csv.read_csv(source, **kwargs)
else: # pragma: no cover
try:
# implementation is UNKNOWN, Narwhals extension using this feature should
# implement `read_csv` function in the top-level namespace.
native_frame = native_namespace.read_csv(source=source)
native_frame = native_namespace.read_csv(source=source, **kwargs)
except AttributeError as e:
msg = "Unknown namespace is expected to implement `read_csv` function."
raise AttributeError(msg) from e
return from_native(native_frame, eager_only=True)


def scan_csv(
source: str,
*,
native_namespace: ModuleType,
source: str, *, native_namespace: ModuleType, **kwargs: Any
) -> LazyFrame[Any]:
"""Lazily read from a CSV file.
Expand All @@ -1035,6 +1035,9 @@ def scan_csv(
Arguments:
source: Path to a file.
native_namespace: The native library to use for DataFrame creation.
kwargs: Extra keyword arguments which are passed to the native CSV reader.
For example, you could use
`nw.scan_csv('file.csv', native_namespace=pd, engine='pyarrow')`.
Returns:
LazyFrame.
Expand Down Expand Up @@ -1071,34 +1074,32 @@ def scan_csv(
1 2 5
2 3 6
"""
return _scan_csv_impl(source, native_namespace=native_namespace)
return _scan_csv_impl(source, native_namespace=native_namespace, **kwargs)


def _scan_csv_impl(
source: str,
*,
native_namespace: ModuleType,
source: str, *, native_namespace: ModuleType, **kwargs: Any
) -> LazyFrame[Any]:
implementation = Implementation.from_native_namespace(native_namespace)
if implementation is Implementation.POLARS:
native_frame = native_namespace.scan_csv(source)
native_frame = native_namespace.scan_csv(source, **kwargs)
elif implementation in (
Implementation.PANDAS,
Implementation.MODIN,
Implementation.CUDF,
):
native_frame = native_namespace.read_csv(source)
native_frame = native_namespace.read_csv(source, **kwargs)
elif implementation is Implementation.PYARROW:
from pyarrow import csv # ignore-banned-import

native_frame = csv.read_csv(source)
native_frame = csv.read_csv(source, **kwargs)
elif implementation is Implementation.DASK:
native_frame = native_namespace.read_csv(source)
native_frame = native_namespace.read_csv(source, **kwargs)
else: # pragma: no cover
try:
# implementation is UNKNOWN, Narwhals extension using this feature should
# implement `scan_csv` function in the top-level namespace.
native_frame = native_namespace.scan_csv(source=source)
native_frame = native_namespace.scan_csv(source=source, **kwargs)
except AttributeError as e:
msg = "Unknown namespace is expected to implement `scan_csv` function."
raise AttributeError(msg) from e
Expand Down
21 changes: 10 additions & 11 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3388,15 +3388,16 @@ def from_numpy(


def read_csv(
source: str,
*,
native_namespace: ModuleType,
source: str, *, native_namespace: ModuleType, **kwargs: Any
) -> DataFrame[Any]:
"""Read a CSV file into a DataFrame.
Arguments:
source: Path to a file.
native_namespace: The native library to use for DataFrame creation.
kwargs: Extra keyword arguments which are passed to the native CSV reader.
For example, you could use
`nw.read_csv('file.csv', native_namespace=pd, engine='pyarrow')`.
Returns:
DataFrame.
Expand Down Expand Up @@ -3441,17 +3442,12 @@ def read_csv(
b: [[4,5,6]]
"""
return _stableify( # type: ignore[no-any-return]
_read_csv_impl(
source,
native_namespace=native_namespace,
)
_read_csv_impl(source, native_namespace=native_namespace, **kwargs)
)


def scan_csv(
source: str,
*,
native_namespace: ModuleType,
source: str, *, native_namespace: ModuleType, **kwargs: Any
) -> LazyFrame[Any]:
"""Lazily read from a CSV file.
Expand All @@ -3461,6 +3457,9 @@ def scan_csv(
Arguments:
source: Path to a file.
native_namespace: The native library to use for DataFrame creation.
kwargs: Extra keyword arguments which are passed to the native CSV reader.
For example, you could use
`nw.scan_csv('file.csv', native_namespace=pd, engine='pyarrow')`.
Returns:
LazyFrame.
Expand Down Expand Up @@ -3498,7 +3497,7 @@ def scan_csv(
2 3 6
"""
return _stableify( # type: ignore[no-any-return]
_scan_csv_impl(source, native_namespace=native_namespace)
_scan_csv_impl(source, native_namespace=native_namespace, **kwargs)
)


Expand Down
17 changes: 12 additions & 5 deletions tests/read_csv_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
import narwhals.stable.v1 as nw_v1
from tests.utils import PANDAS_VERSION
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

if TYPE_CHECKING:
import pytest

data = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]}


Expand Down Expand Up @@ -40,3 +38,12 @@ def test_read_csv_v1(
result = nw_v1.read_csv(filepath, native_namespace=native_namespace)
assert_equal_data(result, data)
assert isinstance(result, nw_v1.DataFrame)


@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
def test_read_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None:
df_pl = pl.DataFrame(data)
filepath = str(tmpdir / "file.csv") # type: ignore[operator]
df_pl.write_csv(filepath)
result = nw.read_csv(filepath, native_namespace=pd, engine="pyarrow")
assert_equal_data(result, data)
17 changes: 12 additions & 5 deletions tests/scan_csv_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
import narwhals.stable.v1 as nw_v1
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

if TYPE_CHECKING:
import pytest

data = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]}


Expand Down Expand Up @@ -41,3 +39,12 @@ def test_scan_csv_v1(
result = nw_v1.scan_csv(filepath, native_namespace=native_namespace)
assert_equal_data(result.collect(), data)
assert isinstance(result, nw_v1.LazyFrame)


@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
def test_scan_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None:
df_pl = pl.DataFrame(data)
filepath = str(tmpdir / "file.csv") # type: ignore[operator]
df_pl.write_csv(filepath)
result = nw.scan_csv(filepath, native_namespace=pd, engine="pyarrow")
assert_equal_data(result, data)

0 comments on commit f483eaa

Please sign in to comment.