diff --git a/narwhals/functions.py b/narwhals/functions.py index 41cb671fd..75cbe96c6 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -942,12 +942,16 @@ def read_csv( source: str, *, native_namespace: ModuleType, + **kwargs: Any, ) -> DataFrame[Any]: """Read a CSV file into a DataFrame. Arguments: source: Path to a file. native_namespace: The native library to use for DataFrame creation. + kwargs: Extra keyword arguments which are passed to the native CSV reader. + For example, you could use + `nw.read_csv('file.csv', native_namespace=pd, engine='pyarrow')`. Returns: DataFrame. @@ -991,13 +995,11 @@ def read_csv( a: [[1,2,3]] b: [[4,5,6]] """ - return _read_csv_impl(source, native_namespace=native_namespace) + return _read_csv_impl(source, native_namespace=native_namespace, **kwargs) def _read_csv_impl( - source: str, - *, - native_namespace: ModuleType, + source: str, *, native_namespace: ModuleType, **kwargs: Any ) -> DataFrame[Any]: implementation = Implementation.from_native_namespace(native_namespace) if implementation in ( @@ -1006,16 +1008,16 @@ def _read_csv_impl( Implementation.MODIN, Implementation.CUDF, ): - native_frame = native_namespace.read_csv(source) + native_frame = native_namespace.read_csv(source, **kwargs) elif implementation is Implementation.PYARROW: from pyarrow import csv # ignore-banned-import - native_frame = csv.read_csv(source) + native_frame = csv.read_csv(source, **kwargs) else: # pragma: no cover try: # implementation is UNKNOWN, Narwhals extension using this feature should # implement `read_csv` function in the top-level namespace. - native_frame = native_namespace.read_csv(source=source) + native_frame = native_namespace.read_csv(source=source, **kwargs) except AttributeError as e: msg = "Unknown namespace is expected to implement `read_csv` function." 
raise AttributeError(msg) from e @@ -1023,9 +1025,7 @@ def _read_csv_impl( def scan_csv( - source: str, - *, - native_namespace: ModuleType, + source: str, *, native_namespace: ModuleType, **kwargs: Any ) -> LazyFrame[Any]: """Lazily read from a CSV file. @@ -1035,6 +1035,9 @@ def scan_csv( Arguments: source: Path to a file. native_namespace: The native library to use for DataFrame creation. + kwargs: Extra keyword arguments which are passed to the native CSV reader. + For example, you could use + `nw.scan_csv('file.csv', native_namespace=pd, engine='pyarrow')`. Returns: LazyFrame. @@ -1071,34 +1074,32 @@ def scan_csv( 1 2 5 2 3 6 """ - return _scan_csv_impl(source, native_namespace=native_namespace) + return _scan_csv_impl(source, native_namespace=native_namespace, **kwargs) def _scan_csv_impl( - source: str, - *, - native_namespace: ModuleType, + source: str, *, native_namespace: ModuleType, **kwargs: Any ) -> LazyFrame[Any]: implementation = Implementation.from_native_namespace(native_namespace) if implementation is Implementation.POLARS: - native_frame = native_namespace.scan_csv(source) + native_frame = native_namespace.scan_csv(source, **kwargs) elif implementation in ( Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF, ): - native_frame = native_namespace.read_csv(source) + native_frame = native_namespace.read_csv(source, **kwargs) elif implementation is Implementation.PYARROW: from pyarrow import csv # ignore-banned-import - native_frame = csv.read_csv(source) + native_frame = csv.read_csv(source, **kwargs) elif implementation is Implementation.DASK: - native_frame = native_namespace.read_csv(source) + native_frame = native_namespace.read_csv(source, **kwargs) else: # pragma: no cover try: # implementation is UNKNOWN, Narwhals extension using this feature should # implement `scan_csv` function in the top-level namespace. 
- native_frame = native_namespace.scan_csv(source=source) + native_frame = native_namespace.scan_csv(source=source, **kwargs) except AttributeError as e: msg = "Unknown namespace is expected to implement `scan_csv` function." raise AttributeError(msg) from e diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 7518a72ee..3733d6e03 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -3388,15 +3388,16 @@ def from_numpy( def read_csv( - source: str, - *, - native_namespace: ModuleType, + source: str, *, native_namespace: ModuleType, **kwargs: Any ) -> DataFrame[Any]: """Read a CSV file into a DataFrame. Arguments: source: Path to a file. native_namespace: The native library to use for DataFrame creation. + kwargs: Extra keyword arguments which are passed to the native CSV reader. + For example, you could use + `nw.read_csv('file.csv', native_namespace=pd, engine='pyarrow')`. Returns: DataFrame. @@ -3441,17 +3442,12 @@ def read_csv( b: [[4,5,6]] """ return _stableify( # type: ignore[no-any-return] - _read_csv_impl( - source, - native_namespace=native_namespace, - ) + _read_csv_impl(source, native_namespace=native_namespace, **kwargs) ) def scan_csv( - source: str, - *, - native_namespace: ModuleType, + source: str, *, native_namespace: ModuleType, **kwargs: Any ) -> LazyFrame[Any]: """Lazily read from a CSV file. @@ -3461,6 +3457,9 @@ def scan_csv( Arguments: source: Path to a file. native_namespace: The native library to use for DataFrame creation. + kwargs: Extra keyword arguments which are passed to the native CSV reader. + For example, you could use + `nw.scan_csv('file.csv', native_namespace=pd, engine='pyarrow')`. Returns: LazyFrame. 
@@ -3498,7 +3497,7 @@ def scan_csv( 2 3 6 """ return _stableify( # type: ignore[no-any-return] - _scan_csv_impl(source, native_namespace=native_namespace) + _scan_csv_impl(source, native_namespace=native_namespace, **kwargs) ) diff --git a/tests/read_csv_test.py b/tests/read_csv_test.py index 528a8434e..fcb755044 100644 --- a/tests/read_csv_test.py +++ b/tests/read_csv_test.py @@ -1,17 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +import pandas as pd import polars as pl +import pytest import narwhals as nw import narwhals.stable.v1 as nw_v1 +from tests.utils import PANDAS_VERSION from tests.utils import ConstructorEager from tests.utils import assert_equal_data -if TYPE_CHECKING: - import pytest - data = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]} @@ -40,3 +38,12 @@ def test_read_csv_v1( result = nw_v1.read_csv(filepath, native_namespace=native_namespace) assert_equal_data(result, data) assert isinstance(result, nw_v1.DataFrame) + + +@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") +def test_read_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None: + df_pl = pl.DataFrame(data) + filepath = str(tmpdir / "file.csv") # type: ignore[operator] + df_pl.write_csv(filepath) + result = nw.read_csv(filepath, native_namespace=pd, engine="pyarrow") + assert_equal_data(result, data) diff --git a/tests/scan_csv_test.py b/tests/scan_csv_test.py index 4f0083667..85b3c26ce 100644 --- a/tests/scan_csv_test.py +++ b/tests/scan_csv_test.py @@ -1,17 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +import pandas as pd import polars as pl +import pytest import narwhals as nw import narwhals.stable.v1 as nw_v1 +from tests.utils import PANDAS_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data -if TYPE_CHECKING: - import pytest - data = {"a": [1, 2, 3], "b": [4.5, 6.7, 8.9], "z": ["x", "y", "w"]} @@ -41,3 +39,12 @@ def test_scan_csv_v1( result = 
nw_v1.scan_csv(filepath, native_namespace=native_namespace) assert_equal_data(result.collect(), data) assert isinstance(result, nw_v1.LazyFrame) + + +@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") +def test_scan_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None: + df_pl = pl.DataFrame(data) + filepath = str(tmpdir / "file.csv") # type: ignore[operator] + df_pl.write_csv(filepath) + result = nw.scan_csv(filepath, native_namespace=pd, engine="pyarrow") + assert_equal_data(result, data)