Skip to content

Commit

Permalink
feat: pyspark and duckdb selectors (#1853)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Jan 26, 2025
1 parent e2ba74b commit 364a625
Show file tree
Hide file tree
Showing 9 changed files with 472 additions and 14 deletions.
2 changes: 2 additions & 0 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ def numeric(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype(
[
dtypes.Int128,
dtypes.Int64,
dtypes.Int32,
dtypes.Int16,
dtypes.Int8,
dtypes.UInt128,
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def numeric(self: Self) -> DaskSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype(
[
dtypes.Int128,
dtypes.Int64,
dtypes.Int32,
dtypes.Int16,
dtypes.Int8,
dtypes.UInt128,
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from duckdb import FunctionExpression

from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.selectors import DuckDBSelectorNamespace
from narwhals._duckdb.utils import narwhals_to_native_dtype
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
Expand All @@ -38,6 +39,12 @@ def __init__(
self._backend_version = backend_version
self._version = version

@property
def selectors(self: Self) -> DuckDBSelectorNamespace:
return DuckDBSelectorNamespace(
backend_version=self._backend_version, version=self._version
)

def all(self: Self) -> DuckDBExpr:
def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [ColumnExpression(col_name) for col_name in df.columns]
Expand Down
212 changes: 212 additions & 0 deletions narwhals/_duckdb/selectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence

from duckdb import ColumnExpression

from narwhals._duckdb.expr import DuckDBExpr
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
import duckdb
from typing_extensions import Self

from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals.dtypes import DType
from narwhals.utils import Version


class DuckDBSelectorNamespace:
def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
) -> None:
self._backend_version = backend_version
self._version = version

def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> DuckDBSelector:
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [
ColumnExpression(col) for col in df.columns if df.schema[col] in dtypes
]

def evalute_output_names(df: DuckDBLazyFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return DuckDBSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evalute_output_names,
alias_output_names=None,
backend_version=self._backend_version,
returns_scalar=False,
version=self._version,
kwargs={},
)

def numeric(self: Self) -> DuckDBSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype(
[
dtypes.Int128,
dtypes.Int64,
dtypes.Int32,
dtypes.Int16,
dtypes.Int8,
dtypes.UInt128,
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
dtypes.UInt8,
dtypes.Float64,
dtypes.Float32,
],
)

def categorical(self: Self) -> DuckDBSelector: # pragma: no cover
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.Categorical])

def string(self: Self) -> DuckDBSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.String])

def boolean(self: Self) -> DuckDBSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype([dtypes.Boolean])

def all(self: Self) -> DuckDBSelector:
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [ColumnExpression(col) for col in df.columns]

return DuckDBSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=lambda df: df.columns,
alias_output_names=None,
backend_version=self._backend_version,
returns_scalar=False,
version=self._version,
kwargs={},
)


class DuckDBSelector(DuckDBExpr):
def __repr__(self: Self) -> str: # pragma: no cover
return (
f"DuckDBSelector("
f"depth={self._depth}, "
f"function_name={self._function_name})"
)

def _to_expr(self: Self) -> DuckDBExpr:
return DuckDBExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
returns_scalar=self._returns_scalar,
version=self._version,
kwargs={},
)

def __sub__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any:
if isinstance(other, DuckDBSelector):

def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names]

def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return DuckDBSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
returns_scalar=self._returns_scalar,
version=self._version,
kwargs={},
)
else:
return self._to_expr() - other

def __or__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any:
if isinstance(other, DuckDBSelector):

def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
rhs = other._call(df)
return [
*(x for x, name in zip(lhs, lhs_names) if name not in rhs_names),
*rhs,
]

def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return DuckDBSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
returns_scalar=self._returns_scalar,
version=self._version,
kwargs={},
)
else:
return self._to_expr() | other

def __and__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any:
if isinstance(other, DuckDBSelector):

def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
lhs = self._call(df)
return [x for x, name in zip(lhs, lhs_names) if name in rhs_names]

def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]:
lhs_names = self._evaluate_output_names(df)
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return DuckDBSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
returns_scalar=self._returns_scalar,
version=self._version,
kwargs={},
)
else:
return self._to_expr() & other

def __invert__(self: Self) -> DuckDBSelector:
return (
DuckDBSelectorNamespace(
backend_version=self._backend_version, version=self._version
).all()
- self
)
2 changes: 2 additions & 0 deletions narwhals/_pandas_like/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def numeric(self: Self) -> PandasSelector:
dtypes = import_dtypes_module(self._version)
return self.by_dtype(
[
dtypes.Int128,
dtypes.Int64,
dtypes.Int32,
dtypes.Int16,
dtypes.Int8,
dtypes.UInt128,
dtypes.UInt64,
dtypes.UInt32,
dtypes.UInt16,
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.selectors import SparkLikeSelectorNamespace
from narwhals.typing import CompliantNamespace

if TYPE_CHECKING:
Expand All @@ -34,6 +35,12 @@ def __init__(
self._backend_version = backend_version
self._version = version

@property
def selectors(self: Self) -> SparkLikeSelectorNamespace:
return SparkLikeSelectorNamespace(
backend_version=self._backend_version, version=self._version
)

def all(self: Self) -> SparkLikeExpr:
def _all(df: SparkLikeLazyFrame) -> list[Column]:
return [F.col(col_name) for col_name in df.columns]
Expand Down
Loading

0 comments on commit 364a625

Please sign in to comment.