diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index fe0201a90..ea6ed4697 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -584,7 +584,7 @@ def write_parquet(self: Self, file: Any) -> None: def write_csv(self: Self, file: Any) -> Any: import pyarrow as pa - import pyarrow.csv as pa_csv # ignore-banned-import + import pyarrow.csv as pa_csv pa_table = self._native_frame if file is None: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index c91e3ff40..ae964cb84 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -1471,8 +1471,8 @@ def __init__(self: Self, series: ArrowSeries) -> None: self._arrow_series = series def len(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.cast(pc.list_value_length(self._arrow_series._native_series), pa.uint32()) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 2d237e596..201786723 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any from typing import Sequence @@ -17,6 +18,7 @@ from narwhals.utils import Version +@lru_cache(maxsize=16) def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: import pyarrow as pa diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index fd1c436e8..65993e280 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any @@ -20,8 +21,8 @@ from narwhals.utils import Version -def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, version: Version) -> DType: - duckdb_dtype = str(duckdb_dtype) +@lru_cache(maxsize=16) +def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType: dtypes = import_dtypes_module(version) if duckdb_dtype == "BIGINT": return dtypes.Int64() @@ -59,16 +60,16 @@ def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, version: Version) -> D [ dtypes.Field( matchstruc_[i][0], - map_duckdb_dtype_to_narwhals_dtype(matchstruc_[i][1], version), + native_to_narwhals_dtype(matchstruc_[i][1], version), ) for i in range(len(matchstruc_)) ] ) if match_ := re.match(r"(.*)\[\]$", duckdb_dtype): - return dtypes.List(map_duckdb_dtype_to_narwhals_dtype(match_.group(1), version)) + return dtypes.List(native_to_narwhals_dtype(match_.group(1), version)) if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype): return dtypes.Array( - map_duckdb_dtype_to_narwhals_dtype(match_.group(1), version), + native_to_narwhals_dtype(match_.group(1), version), int(match_.group(2)), ) return dtypes.Unknown() @@ -111,9 +112,7 @@ def select( def __getattr__(self, attr: str) -> Any: if attr == "schema": return { - column_name: map_duckdb_dtype_to_narwhals_dtype( - duckdb_dtype, self._version - ) + column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version) for column_name, duckdb_dtype in zip( self._native_frame.columns, self._native_frame.types ) diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py index c0340bffd..dc7485e98 100644 --- a/narwhals/_duckdb/series.py +++ b/narwhals/_duckdb/series.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from typing import Any -from narwhals._duckdb.dataframe import map_duckdb_dtype_to_narwhals_dtype +from narwhals._duckdb.dataframe import native_to_narwhals_dtype from narwhals.dependencies import get_duckdb if TYPE_CHECKING: @@ -25,8 +25,8 @@ def __native_namespace__(self) -> ModuleType: def __getattr__(self, attr: str) -> Any: if attr == "dtype": - return map_duckdb_dtype_to_narwhals_dtype( - self._native_series.types[0], self._version + return native_to_narwhals_dtype( + str(self._native_series.types[0]), self._version ) msg = ( # pragma: no cover f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 454354a7e..62c5f7a18 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any @@ -18,7 +19,8 @@ from narwhals.utils import Version -def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType: +@lru_cache(maxsize=16) +def native_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType: dtypes = import_dtypes_module(version) if ibis_dtype.is_int64(): return dtypes.Int64() @@ -49,15 +51,13 @@ def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType if ibis_dtype.is_timestamp(): return dtypes.Datetime() if ibis_dtype.is_array(): - return dtypes.List( - map_ibis_dtype_to_narwhals_dtype(ibis_dtype.value_type, version) - ) + return dtypes.List(native_to_narwhals_dtype(ibis_dtype.value_type, version)) if ibis_dtype.is_struct(): return dtypes.Struct( [ dtypes.Field( ibis_dtype_name, - map_ibis_dtype_to_narwhals_dtype(ibis_dtype_field, version), + native_to_narwhals_dtype(ibis_dtype_field, version), ) for ibis_dtype_name, ibis_dtype_field in ibis_dtype.items() ] @@ -108,7 +108,7 @@ def select( def __getattr__(self, attr: str) -> Any: if attr == "schema": return { - column_name: map_ibis_dtype_to_narwhals_dtype(ibis_dtype, self._version) + column_name: native_to_narwhals_dtype(ibis_dtype, self._version) for column_name, ibis_dtype in self._native_frame.schema().items() } elif attr == "columns": diff --git a/narwhals/_ibis/series.py b/narwhals/_ibis/series.py index 925ea40a1..5629c0cfc 100644 --- a/narwhals/_ibis/series.py +++ b/narwhals/_ibis/series.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from typing import Any -from narwhals._ibis.dataframe import map_ibis_dtype_to_narwhals_dtype +from narwhals._ibis.dataframe import native_to_narwhals_dtype from narwhals.dependencies import get_ibis if TYPE_CHECKING: @@ -25,9 +25,7 @@ def __native_namespace__(self) -> ModuleType: def __getattr__(self, attr: str) -> Any: if attr == "dtype": - return map_ibis_dtype_to_narwhals_dtype( - self._native_series.type(), self._version - ) + return native_to_narwhals_dtype(self._native_series.type(), self._version) msg = ( f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" "If you would like to see this kind of object better supported in " diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index c2be65264..517ba3d37 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any from typing import Iterable @@ -342,13 +343,11 @@ def rename( return obj.rename(*args, **kwargs, copy=False) # type: ignore[attr-defined, no-any-return] -def native_to_narwhals_dtype( - native_column: Any, version: Version, implementation: Implementation +@lru_cache(maxsize=16) +def non_object_native_to_narwhals_dtype( + dtype: str, version: Version, _implementation: Implementation ) -> DType: - dtype = str(native_column.dtype) - dtypes = import_dtypes_module(version) - if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}: @@ -400,36 +399,47 @@ def native_to_narwhals_dtype( return dtypes.Duration(du_time_unit) if dtype == "date32[day][pyarrow]": return dtypes.Date() + return dtypes.Unknown() # pragma: no cover + + +def native_to_narwhals_dtype( + native_column: Any, version: Version, implementation: Implementation +) -> DType: + dtype = str(native_column.dtype) + + dtypes = import_dtypes_module(version) + if dtype.startswith(("large_list", "list", "struct", "fixed_size_list")): return arrow_native_to_narwhals_dtype(native_column.dtype.pyarrow_dtype, version) - if dtype == "object": - if implementation is Implementation.DASK: - # Dask columns are lazy, so we can't inspect values. - # The most useful assumption is probably String + if dtype != "object": + return non_object_native_to_narwhals_dtype(dtype, version, implementation) + if implementation is Implementation.DASK: + # Dask columns are lazy, so we can't inspect values. + # The most useful assumption is probably String + return dtypes.String() + if implementation is Implementation.PANDAS: # pragma: no cover + # This is the most efficient implementation for pandas, + # and doesn't require the interchange protocol + import pandas as pd + + dtype = pd.api.types.infer_dtype(native_column, skipna=True) + if dtype == "string": return dtypes.String() - if implementation is Implementation.PANDAS: # pragma: no cover - # This is the most efficient implementation for pandas, - # and doesn't require the interchange protocol - import pandas as pd - - dtype = pd.api.types.infer_dtype(native_column, skipna=True) - if dtype == "string": - return dtypes.String() - return dtypes.Object() - else: # pragma: no cover - df = native_column.to_frame() - if hasattr(df, "__dataframe__"): - from narwhals._interchange.dataframe import ( - map_interchange_dtype_to_narwhals_dtype, - ) + return dtypes.Object() + else: # pragma: no cover + df = native_column.to_frame() + if hasattr(df, "__dataframe__"): + from narwhals._interchange.dataframe import ( + map_interchange_dtype_to_narwhals_dtype, + ) - try: - return map_interchange_dtype_to_narwhals_dtype( - df.__dataframe__().get_column(0).dtype, version - ) - except Exception: # noqa: BLE001, S110 - pass - return dtypes.Unknown() + try: + return map_interchange_dtype_to_narwhals_dtype( + df.__dataframe__().get_column(0).dtype, version + ) + except Exception: # noqa: BLE001, S110 + pass + return dtypes.Unknown() # pragma: no cover def get_dtype_backend(dtype: Any, implementation: Implementation) -> str: @@ -588,7 +598,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 if isinstance_or_issubclass(dtype, dtypes.List): if implementation is Implementation.PANDAS and backend_version >= (2, 2): try: - import pandas as pd # ignore-banned-import + import pandas as pd import pyarrow as pa # ignore-banned-import except ImportError as exc: # pragma: no cover msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}" diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 44f632c90..4ea1757fd 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any from typing import Literal @@ -64,6 +65,7 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An } +@lru_cache(maxsize=16) def native_to_narwhals_dtype( dtype: pl.DataType, version: Version, @@ -104,11 +106,11 @@ def native_to_narwhals_dtype( return dtypes.Enum() if dtype == pl.Date: return dtypes.Date() - if dtype == pl.Datetime or isinstance(dtype, pl.Datetime): + if dtype == pl.Datetime: dt_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") dt_time_zone = getattr(dtype, "time_zone", None) return dtypes.Datetime(time_unit=dt_time_unit, time_zone=dt_time_zone) - if dtype == pl.Duration or isinstance(dtype, pl.Duration): + if dtype == pl.Duration: du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us") return dtypes.Duration(time_unit=du_time_unit) if dtype == pl.Struct: diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 4a22bff7e..d3c646a9c 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import lru_cache from typing import TYPE_CHECKING from typing import Any @@ -16,6 +17,7 @@ from narwhals.utils import Version +@lru_cache(maxsize=16) def native_to_narwhals_dtype( dtype: pyspark_types.DataType, version: Version,