perf: use lru_cache for native_to_narwhals_dtype (#1564)
* perf: use lru_cache for native_to_narwhals_dtype

* cleanups

* cleanups
MarcoGorelli authored Dec 12, 2024
1 parent 1a53db3 commit d94775a
Showing 10 changed files with 71 additions and 58 deletions.
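At a glance, the change memoises each backend's dtype-translation helper with functools.lru_cache, so building a schema no longer re-runs the same string/type matching for every column that shares a dtype. A minimal sketch of the pattern, with toy string results standing in for the real narwhals dtype classes:

from functools import lru_cache


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: str, version: str) -> str:
    # Toy stand-in for the real mapping: the body only runs on a cache
    # miss, so repeated lookups of the same (dtype, version) key are
    # answered from the cache.
    if dtype == "BIGINT":
        return "Int64"
    if dtype == "VARCHAR":
        return "String"
    return "Unknown"


native_to_narwhals_dtype("BIGINT", "main")    # computed
native_to_narwhals_dtype("BIGINT", "main")    # served from the cache
print(native_to_narwhals_dtype.cache_info())  # hits=1, misses=1, currsize=1

The maxsize=16 used throughout presumably reflects that a given workload only touches a handful of distinct dtypes, so a tiny cache is enough.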
2 changes: 1 addition & 1 deletion narwhals/_arrow/dataframe.py
@@ -584,7 +584,7 @@ def write_parquet(self: Self, file: Any) -> None:

def write_csv(self: Self, file: Any) -> Any:
import pyarrow as pa
import pyarrow.csv as pa_csv # ignore-banned-import
import pyarrow.csv as pa_csv

pa_table = self._native_frame
if file is None:
4 changes: 2 additions & 2 deletions narwhals/_arrow/series.py
@@ -1471,8 +1471,8 @@ def __init__(self: Self, series: ArrowSeries) -> None:
self._arrow_series = series

def len(self: Self) -> ArrowSeries:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()
import pyarrow as pa
import pyarrow.compute as pc

return self._arrow_series._from_native_series(
pc.cast(pc.list_value_length(self._arrow_series._native_series), pa.uint32())
2 changes: 2 additions & 0 deletions narwhals/_arrow/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
@@ -17,6 +18,7 @@
from narwhals.utils import Version


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:
import pyarrow as pa

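One reason this works without any conversion is that pyarrow.DataType objects hash and compare by value, so they can serve directly as lru_cache keys; assuming pyarrow is installed, this is easy to confirm:

import pyarrow as pa

# Two independently constructed types compare equal and hash identically,
# so a second native_to_narwhals_dtype(pa.int64(), ...) call is a cache hit.
print(pa.int64() == pa.int64())              # True
print(hash(pa.int64()) == hash(pa.int64()))  # True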
15 changes: 7 additions & 8 deletions narwhals/_duckdb/dataframe.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any

@@ -20,8 +21,8 @@
from narwhals.utils import Version


def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, version: Version) -> DType:
duckdb_dtype = str(duckdb_dtype)
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
dtypes = import_dtypes_module(version)
if duckdb_dtype == "BIGINT":
return dtypes.Int64()
@@ -59,16 +60,16 @@ def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, version: Version) -> DType:
[
dtypes.Field(
matchstruc_[i][0],
map_duckdb_dtype_to_narwhals_dtype(matchstruc_[i][1], version),
native_to_narwhals_dtype(matchstruc_[i][1], version),
)
for i in range(len(matchstruc_))
]
)
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(map_duckdb_dtype_to_narwhals_dtype(match_.group(1), version))
return dtypes.List(native_to_narwhals_dtype(match_.group(1), version))
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype):
return dtypes.Array(
map_duckdb_dtype_to_narwhals_dtype(match_.group(1), version),
native_to_narwhals_dtype(match_.group(1), version),
int(match_.group(2)),
)
return dtypes.Unknown()
@@ -111,9 +112,7 @@ def select(
def __getattr__(self, attr: str) -> Any:
if attr == "schema":
return {
column_name: map_duckdb_dtype_to_narwhals_dtype(
duckdb_dtype, self._version
)
column_name: native_to_narwhals_dtype(str(duckdb_dtype), self._version)
for column_name, duckdb_dtype in zip(
self._native_frame.columns, self._native_frame.types
)
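Previously the DuckDB mapper stringified its argument internally (duckdb_dtype = str(duckdb_dtype)); with lru_cache on top, the str() call moves to the call sites so the cache is keyed on a plain string. A toy sketch of that calling pattern, where FakeDuckDBType is a hypothetical stand-in for whatever object self._native_frame.types yields:

from functools import lru_cache


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(duckdb_dtype: str) -> str:
    # Keyed on the string form of the type, as in the diff.
    return {"BIGINT": "Int64", "VARCHAR": "String"}.get(duckdb_dtype, "Unknown")


class FakeDuckDBType:
    # Hypothetical stand-in for a native DuckDB type object.
    def __str__(self) -> str:
        return "BIGINT"


# Callers convert to str first, so equal types share one cache entry.
print(native_to_narwhals_dtype(str(FakeDuckDBType())))  # Int64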
6 changes: 3 additions & 3 deletions narwhals/_duckdb/series.py
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any

from narwhals._duckdb.dataframe import map_duckdb_dtype_to_narwhals_dtype
from narwhals._duckdb.dataframe import native_to_narwhals_dtype
from narwhals.dependencies import get_duckdb

if TYPE_CHECKING:
@@ -25,8 +25,8 @@ def __native_namespace__(self) -> ModuleType:

def __getattr__(self, attr: str) -> Any:
if attr == "dtype":
return map_duckdb_dtype_to_narwhals_dtype(
self._native_series.types[0], self._version
return native_to_narwhals_dtype(
str(self._native_series.types[0]), self._version
)
msg = ( # pragma: no cover
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
12 changes: 6 additions & 6 deletions narwhals/_ibis/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any

@@ -18,7 +19,8 @@
from narwhals.utils import Version


def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType:
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType:
dtypes = import_dtypes_module(version)
if ibis_dtype.is_int64():
return dtypes.Int64()
@@ -49,15 +51,13 @@ def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, version: Version) -> DType:
if ibis_dtype.is_timestamp():
return dtypes.Datetime()
if ibis_dtype.is_array():
return dtypes.List(
map_ibis_dtype_to_narwhals_dtype(ibis_dtype.value_type, version)
)
return dtypes.List(native_to_narwhals_dtype(ibis_dtype.value_type, version))
if ibis_dtype.is_struct():
return dtypes.Struct(
[
dtypes.Field(
ibis_dtype_name,
map_ibis_dtype_to_narwhals_dtype(ibis_dtype_field, version),
native_to_narwhals_dtype(ibis_dtype_field, version),
)
for ibis_dtype_name, ibis_dtype_field in ibis_dtype.items()
]
@@ -108,7 +108,7 @@ def select(
def __getattr__(self, attr: str) -> Any:
if attr == "schema":
return {
column_name: map_ibis_dtype_to_narwhals_dtype(ibis_dtype, self._version)
column_name: native_to_narwhals_dtype(ibis_dtype, self._version)
for column_name, ibis_dtype in self._native_frame.schema().items()
}
elif attr == "columns":
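Both the DuckDB and Ibis mappers recurse through the renamed native_to_narwhals_dtype for nested types (lists, structs, arrays), so element dtypes get cached as well. A toy string-based sketch of that recursion:

from functools import lru_cache


@lru_cache(maxsize=16)
def map_dtype(dtype: str) -> str:
    # Nested types recurse through the same cached function, so the
    # element type (e.g. the BIGINT inside BIGINT[]) gets its own entry.
    if dtype.endswith("[]"):
        return f"List[{map_dtype(dtype[:-2])}]"
    return {"BIGINT": "Int64", "VARCHAR": "String"}.get(dtype, "Unknown")


print(map_dtype("BIGINT[]"))   # List[Int64]
print(map_dtype.cache_info())  # two entries: one for BIGINT[], one for BIGINT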
6 changes: 2 additions & 4 deletions narwhals/_ibis/series.py
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any

from narwhals._ibis.dataframe import map_ibis_dtype_to_narwhals_dtype
from narwhals._ibis.dataframe import native_to_narwhals_dtype
from narwhals.dependencies import get_ibis

if TYPE_CHECKING:
@@ -25,9 +25,7 @@ def __native_namespace__(self) -> ModuleType:

def __getattr__(self, attr: str) -> Any:
if attr == "dtype":
return map_ibis_dtype_to_narwhals_dtype(
self._native_series.type(), self._version
)
return native_to_narwhals_dtype(self._native_series.type(), self._version)
msg = (
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"If you would like to see this kind of object better supported in "
74 changes: 42 additions & 32 deletions narwhals/_pandas_like/utils.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
@@ -342,13 +343,11 @@ def rename(
return obj.rename(*args, **kwargs, copy=False) # type: ignore[attr-defined, no-any-return]


def native_to_narwhals_dtype(
native_column: Any, version: Version, implementation: Implementation
@lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(
dtype: str, version: Version, _implementation: Implementation
) -> DType:
dtype = str(native_column.dtype)

dtypes = import_dtypes_module(version)

if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
return dtypes.Int64()
if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}:
@@ -400,36 +399,47 @@ def native_to_narwhals_dtype(
return dtypes.Duration(du_time_unit)
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
return dtypes.Unknown() # pragma: no cover


def native_to_narwhals_dtype(
native_column: Any, version: Version, implementation: Implementation
) -> DType:
dtype = str(native_column.dtype)

dtypes = import_dtypes_module(version)

if dtype.startswith(("large_list", "list", "struct", "fixed_size_list")):
return arrow_native_to_narwhals_dtype(native_column.dtype.pyarrow_dtype, version)
if dtype == "object":
if implementation is Implementation.DASK:
# Dask columns are lazy, so we can't inspect values.
# The most useful assumption is probably String
if dtype != "object":
return non_object_native_to_narwhals_dtype(dtype, version, implementation)
if implementation is Implementation.DASK:
# Dask columns are lazy, so we can't inspect values.
# The most useful assumption is probably String
return dtypes.String()
if implementation is Implementation.PANDAS: # pragma: no cover
# This is the most efficient implementation for pandas,
# and doesn't require the interchange protocol
import pandas as pd

dtype = pd.api.types.infer_dtype(native_column, skipna=True)
if dtype == "string":
return dtypes.String()
if implementation is Implementation.PANDAS: # pragma: no cover
# This is the most efficient implementation for pandas,
# and doesn't require the interchange protocol
import pandas as pd

dtype = pd.api.types.infer_dtype(native_column, skipna=True)
if dtype == "string":
return dtypes.String()
return dtypes.Object()
else: # pragma: no cover
df = native_column.to_frame()
if hasattr(df, "__dataframe__"):
from narwhals._interchange.dataframe import (
map_interchange_dtype_to_narwhals_dtype,
)
return dtypes.Object()
else: # pragma: no cover
df = native_column.to_frame()
if hasattr(df, "__dataframe__"):
from narwhals._interchange.dataframe import (
map_interchange_dtype_to_narwhals_dtype,
)

try:
return map_interchange_dtype_to_narwhals_dtype(
df.__dataframe__().get_column(0).dtype, version
)
except Exception: # noqa: BLE001, S110
pass
return dtypes.Unknown()
try:
return map_interchange_dtype_to_narwhals_dtype(
df.__dataframe__().get_column(0).dtype, version
)
except Exception: # noqa: BLE001, S110
pass
return dtypes.Unknown() # pragma: no cover


def get_dtype_backend(dtype: Any, implementation: Implementation) -> str:
@@ -588,7 +598,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915
if isinstance_or_issubclass(dtype, dtypes.List):
if implementation is Implementation.PANDAS and backend_version >= (2, 2):
try:
import pandas as pd # ignore-banned-import
import pandas as pd
import pyarrow as pa # ignore-banned-import
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
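The pandas-like path is split rather than cached wholesale: the string-keyed branch becomes non_object_native_to_narwhals_dtype under lru_cache, while the object-dtype branch stays uncached because it has to inspect column values, which a dtype string alone cannot capture. A simplified sketch of that wrapper shape, with toy types rather than the real narwhals signatures:

from functools import lru_cache


@lru_cache(maxsize=16)
def non_object_dtype(dtype: str) -> str:
    # Cacheable: the answer depends only on the dtype string.
    return {"int64": "Int64", "float64": "Float64"}.get(dtype, "Unknown")


def column_dtype(values: tuple, dtype: str) -> str:
    # Uncached wrapper: object columns need their values inspected,
    # so only the pure string mapping above goes through the cache.
    if dtype != "object":
        return non_object_dtype(dtype)
    return "String" if all(isinstance(v, str) for v in values) else "Object"


print(column_dtype((1, 2, 3), "int64"))    # Int64
print(column_dtype(("a", "b"), "object"))  # String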
6 changes: 4 additions & 2 deletions narwhals/_polars/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -64,6 +65,7 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, Any]]:
}


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(
dtype: pl.DataType,
version: Version,
@@ -104,11 +106,11 @@ def native_to_narwhals_dtype(
return dtypes.Enum()
if dtype == pl.Date:
return dtypes.Date()
if dtype == pl.Datetime or isinstance(dtype, pl.Datetime):
if dtype == pl.Datetime:
dt_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
dt_time_zone = getattr(dtype, "time_zone", None)
return dtypes.Datetime(time_unit=dt_time_unit, time_zone=dt_time_zone)
if dtype == pl.Duration or isinstance(dtype, pl.Duration):
if dtype == pl.Duration:
du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
return dtypes.Duration(time_unit=du_time_unit)
if dtype == pl.Struct:
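Dropping the isinstance checks here looks safe because, in Polars, comparing a parametrised dtype instance against the bare class already matches on the base type; assuming polars is installed:

import polars as pl

# Equality against the class matches any Datetime regardless of its
# parameters, so the extra isinstance() check was redundant.
dt = pl.Datetime(time_unit="ms", time_zone="UTC")
print(dt == pl.Datetime)  # True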
2 changes: 2 additions & 0 deletions narwhals/_spark_like/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any

@@ -16,6 +17,7 @@
from narwhals.utils import Version


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(
dtype: pyspark_types.DataType,
version: Version,
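To sanity-check the perf claim locally, one rough approach is to time repeated lookups with and without a cache; a micro-benchmark sketch (illustrative only - the real gain comes from the much longer matching logic and imports inside the actual mappers, and numbers vary by machine):

import timeit
from functools import lru_cache


def uncached(dtype: str) -> str:
    # Roughly mimics the long if/elif chains in the real mappers.
    for candidate, result in [
        ("int64", "Int64"), ("int32", "Int32"), ("float64", "Float64"),
        ("float32", "Float32"), ("string", "String"), ("bool", "Boolean"),
        ("date32[day]", "Date"),
    ]:
        if dtype == candidate:
            return result
    return "Unknown"


cached = lru_cache(maxsize=16)(uncached)

print("uncached:", timeit.timeit(lambda: uncached("date32[day]"), number=500_000))
print("cached:  ", timeit.timeit(lambda: cached("date32[day]"), number=500_000))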
