Skip to content

Commit

Permalink
feat: support std and var with ddof !=1 in pandas-like group by (#1645)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
FBruzzesi and MarcoGorelli authored Dec 22, 2024
1 parent e112a99 commit 9f09ea0
Show file tree
Hide file tree
Showing 16 changed files with 250 additions and 90 deletions.
18 changes: 9 additions & 9 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def alias(self: Self, name: str) -> Self:
output_names=[name],
backend_version=self._backend_version,
version=self._version,
kwargs={"name": name},
kwargs={**self._kwargs, "name": name},
)

def null_count(self: Self) -> Self:
Expand Down Expand Up @@ -436,7 +436,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=self._output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={"keys": keys},
kwargs={**self._kwargs, "keys": keys},
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -478,7 +478,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=self._output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={"function": function, "return_dtype": return_dtype},
kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype},
)

def is_finite(self: Self) -> Self:
Expand Down Expand Up @@ -810,7 +810,7 @@ def keep(self: Self) -> ArrowExpr:
output_names=root_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
Expand All @@ -837,7 +837,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"function": function},
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
Expand All @@ -862,7 +862,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"prefix": prefix},
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
Expand All @@ -888,7 +888,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"suffix": suffix},
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> ArrowExpr:
Expand All @@ -914,7 +914,7 @@ def to_lowercase(self: Self) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand All @@ -940,7 +940,7 @@ def to_uppercase(self: Self) -> ArrowExpr:
output_names=output_names,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)


Expand Down
7 changes: 6 additions & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
function_name="concat_str",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
kwargs={"separator": separator, "ignore_nulls": ignore_nulls},
kwargs={
"exprs": exprs,
"more_exprs": more_exprs,
"separator": separator,
"ignore_nulls": ignore_nulls,
},
)


Expand Down
6 changes: 3 additions & 3 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={"other": other},
kwargs={**self._kwargs, "other": other},
)
else:
return self._to_expr() - other
Expand All @@ -145,7 +145,7 @@ def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={"other": other},
kwargs={**self._kwargs, "other": other},
)
else:
return self._to_expr() | other
Expand All @@ -166,7 +166,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]:
output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={"other": other},
kwargs={**self._kwargs, "other": other},
)
else:
return self._to_expr() & other
Expand Down
18 changes: 9 additions & 9 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=self._returns_scalar or returns_scalar,
backend_version=self._backend_version,
version=self._version,
kwargs=kwargs,
kwargs={**self._kwargs, **kwargs},
)

def alias(self, name: str) -> Self:
Expand All @@ -194,7 +194,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=self._returns_scalar,
backend_version=self._backend_version,
version=self._version,
kwargs={"name": name},
kwargs={**self._kwargs, "name": name},
)

def __add__(self, other: Any) -> Self:
Expand Down Expand Up @@ -847,7 +847,7 @@ def func(df: DaskLazyFrame) -> list[Any]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"keys": keys},
kwargs={**self._kwargs, "keys": keys},
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -1334,7 +1334,7 @@ def keep(self: Self) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> DaskExpr:
Expand Down Expand Up @@ -1362,7 +1362,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"function": function},
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> DaskExpr:
Expand All @@ -1388,7 +1388,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"prefix": prefix},
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> DaskExpr:
Expand All @@ -1415,7 +1415,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={"suffix": suffix},
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> DaskExpr:
Expand All @@ -1442,7 +1442,7 @@ def to_lowercase(self: Self) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> DaskExpr:
Expand All @@ -1469,5 +1469,5 @@ def to_uppercase(self: Self) -> DaskExpr:
returns_scalar=self._compliant_expr._returns_scalar,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={},
kwargs=self._compliant_expr._kwargs,
)
21 changes: 13 additions & 8 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
Expand All @@ -178,7 +178,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
Expand All @@ -197,7 +197,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def concat(
Expand Down Expand Up @@ -279,7 +279,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
Expand All @@ -301,7 +301,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
Expand All @@ -323,7 +323,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"exprs": exprs},
)

def when(
Expand Down Expand Up @@ -388,7 +388,12 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={
"exprs": exprs,
"more_exprs": more_exprs,
"separator": separator,
"ignore_nulls": ignore_nulls,
},
)


Expand Down Expand Up @@ -451,7 +456,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen:
returns_scalar=self._returns_scalar,
backend_version=self._backend_version,
version=self._version,
kwargs={},
kwargs={"value": value},
)


Expand Down
5 changes: 2 additions & 3 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,13 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]:
): # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

return plx._create_expr_from_callable( # type: ignore[return-value]
func, # type: ignore[arg-type]
depth=expr._depth + 1,
function_name=f"{expr._function_name}->{attr}",
root_names=root_names,
output_names=output_names,
kwargs=kwargs,
kwargs={**expr._kwargs, **kwargs},
)


Expand Down Expand Up @@ -272,7 +271,7 @@ def reuse_series_namespace_implementation(
function_name=f"{expr._function_name}->{series_namespace}.{attr}",
root_names=expr._root_names,
output_names=expr._output_names,
kwargs=kwargs,
kwargs={**expr._kwargs, **kwargs},
)


Expand Down
Loading

0 comments on commit 9f09ea0

Please sign in to comment.