Skip to content

Commit

Permalink
feat: consistently raise ColumnNotFoundError for missing columns in…
Browse files Browse the repository at this point in the history
… `select` and `drop` (#1389)



---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
raisadz and MarcoGorelli authored Nov 17, 2024
1 parent 0908f20 commit 0a5c8b9
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 38 deletions.
26 changes: 17 additions & 9 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal
from typing import Sequence

from narwhals._exceptions import ColumnNotFoundError
from narwhals._expression_parsing import reuse_series_implementation
from narwhals._expression_parsing import reuse_series_namespace_implementation
from narwhals.dependencies import get_numpy
Expand Down Expand Up @@ -64,15 +65,22 @@ def from_column_names(
from narwhals._arrow.series import ArrowSeries

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [
ArrowSeries(
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
try:
return [
ArrowSeries(
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
) from e

return cls(
func,
Expand Down
10 changes: 9 additions & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._exceptions import ColumnNotFoundError
from narwhals._pandas_like.utils import calculate_timestamp_date
from narwhals._pandas_like.utils import calculate_timestamp_datetime
from narwhals._pandas_like.utils import native_to_narwhals_dtype
Expand Down Expand Up @@ -67,7 +68,14 @@ def from_column_names(
dtypes: DTypes,
) -> Self:
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
return [df._native_frame[column_name] for column_name in column_names]
try:
return [df._native_frame[column_name] for column_name in column_names]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
) from e

return cls(
func,
Expand Down
31 changes: 30 additions & 1 deletion narwhals/_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,36 @@
from __future__ import annotations


class ColumnNotFoundError(Exception): ...
class FormattedKeyError(KeyError):
"""KeyError with formatted error message.
Python's `KeyError` has special casing around formatting
(see https://bugs.python.org/issue2651). Use this class when the error
message has newlines and other special format characters.
Needed by https://github.com/tensorflow/tensorflow/issues/36857.
"""

def __init__(self, message: str) -> None:
self.message = message

def __str__(self) -> str:
return self.message


class ColumnNotFoundError(FormattedKeyError):
def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)

@classmethod
def from_missing_and_available_column_names(
cls, missing_columns: list[str], available_columns: list[str]
) -> ColumnNotFoundError:
message = (
f"The following columns were not found: {missing_columns}"
f"\n\nHint: Did you mean one of these columns: {available_columns}?"
)
return ColumnNotFoundError(message)


class InvalidOperationError(Exception): ...
26 changes: 17 additions & 9 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Literal
from typing import Sequence

from narwhals._exceptions import ColumnNotFoundError
from narwhals._expression_parsing import reuse_series_implementation
from narwhals._expression_parsing import reuse_series_namespace_implementation
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -72,15 +73,22 @@ def from_column_names(
dtypes: DTypes,
) -> Self:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
return [
PandasLikeSeries(
df._native_frame[column_name],
implementation=df._implementation,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
try:
return [
PandasLikeSeries(
df._native_frame[column_name],
implementation=df._implementation,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
) from e

return cls(
func,
Expand Down
16 changes: 15 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
from narwhals._exceptions import ColumnNotFoundError
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -659,5 +660,18 @@ def select_columns_by_name(
):
# See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
# for why we need this
available_columns = df.columns.tolist() # type: ignore[attr-defined]
missing_columns = [x for x in column_names if x not in available_columns]
if missing_columns: # pragma: no cover
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns, available_columns
)
return df.loc[:, column_names] # type: ignore[no-any-return, attr-defined]
return df[column_names] # type: ignore[no-any-return, index]
try:
return df[column_names] # type: ignore[no-any-return, index]
except KeyError as e:
available_columns = df.columns.tolist() # type: ignore[attr-defined]
missing_columns = [x for x in column_names if x not in available_columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns, available_columns
) from e
33 changes: 23 additions & 10 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal
from typing import Sequence

from narwhals._exceptions import ColumnNotFoundError
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
Expand Down Expand Up @@ -79,10 +80,17 @@ def __getattr__(self, attr: str) -> Any:
}

def func(*args: Any, **kwargs: Any) -> Any:
import polars as pl # ignore-banned-import()

args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment]
return self._from_native_object(
getattr(self._native_frame, attr)(*args, **kwargs)
)
try:
return self._from_native_object(
getattr(self._native_frame, attr)(*args, **kwargs)
)
except pl.exceptions.ColumnNotFoundError as e:
msg = str(e)
msg += f"\n\nHint: Did you mean one of these columns: {self.columns}?"
raise ColumnNotFoundError(str(e)) from e

return func

Expand Down Expand Up @@ -219,12 +227,10 @@ def with_row_index(self, name: str) -> Any:
return self._from_native_frame(self._native_frame.with_row_index(name))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
if self._backend_version < (1, 0, 0):
to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
return self._from_native_frame(self._native_frame.drop(to_drop))
return self._from_native_frame(self._native_frame.drop(columns, strict=strict))
to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
return self._from_native_frame(self._native_frame.drop(to_drop))

def unpivot(
self: Self,
Expand Down Expand Up @@ -341,8 +347,15 @@ def collect_schema(self) -> dict[str, DType]:
}

def collect(self) -> PolarsDataFrame:
import polars as pl # ignore-banned-import

try:
result = self._native_frame.collect()
except pl.exceptions.ColumnNotFoundError as e:
raise ColumnNotFoundError(str(e)) from e

return PolarsDataFrame(
self._native_frame.collect(),
result,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
Expand Down
14 changes: 7 additions & 7 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,16 +638,16 @@ def parse_columns_to_drop(
columns: Iterable[str],
strict: bool, # noqa: FBT001
) -> list[str]:
cols = set(compliant_frame.columns)
cols = compliant_frame.columns
to_drop = list(columns)

if strict:
for d in to_drop:
if d not in cols:
msg = f'"{d}" not found'
raise ColumnNotFoundError(msg)
missing_columns = [x for x in to_drop if x not in cols]
if missing_columns:
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns, available_columns=cols
)
else:
to_drop = list(cols.intersection(set(to_drop)))
to_drop = list(set(cols).intersection(set(to_drop)))
return to_drop


Expand Down
42 changes: 42 additions & 0 deletions tests/frame/select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals._exceptions import ColumnNotFoundError
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

Expand Down Expand Up @@ -53,3 +55,43 @@ def test_comparison_with_list_error_message() -> None:
nw.from_native(pa.chunked_array([[1, 2, 3]]), series_only=True) == [1, 2, 3] # noqa: B015
with pytest.raises(ValueError, match=msg):
nw.from_native(pd.Series([[1, 2, 3]]), series_only=True) == [1, 2, 3] # noqa: B015


def test_missing_columns(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
selected_columns = ["a", "e", "f"]
msg = (
r"The following columns were not found: \[.*\]"
r"\n\nHint: Did you mean one of these columns: \['a', 'b', 'z'\]?"
)
if "polars" in str(constructor):
# In the lazy case, Polars only errors when we call `collect`,
# and we have no way to recover exactly which columns the user
# tried selecting. So, we just emit their message (which varies
# across versions...)
msg = "e|f"
if isinstance(df, nw.LazyFrame):
with pytest.raises(ColumnNotFoundError, match=msg):
df.select(selected_columns).collect()
else:
with pytest.raises(ColumnNotFoundError, match=msg):
df.select(selected_columns)
if POLARS_VERSION >= (1,):
# Old Polars versions wouldn't raise an error
# at all here
if isinstance(df, nw.LazyFrame):
with pytest.raises(ColumnNotFoundError, match=msg):
df.drop(selected_columns, strict=True).collect()
else:
with pytest.raises(ColumnNotFoundError, match=msg):
df.drop(selected_columns, strict=True)
else: # pragma: no cover
pass
else:
with pytest.raises(ColumnNotFoundError, match=msg):
df.select(selected_columns)
with pytest.raises(ColumnNotFoundError, match=msg):
df.drop(selected_columns, strict=True)
with pytest.raises(ColumnNotFoundError, match=msg):
df.select(nw.col("fdfa"))

0 comments on commit 0a5c8b9

Please sign in to comment.