From 9e8cdee1571f7e3bcd74e2560f2e43b7d6f76275 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sat, 7 Dec 2024 13:35:36 +0000 Subject: [PATCH] move some imports deeper (#1532) --- narwhals/_arrow/dataframe.py | 24 +-- narwhals/_arrow/group_by.py | 10 +- narwhals/_arrow/namespace.py | 12 +- narwhals/_arrow/series.py | 226 ++++++++++++++--------------- narwhals/_arrow/utils.py | 32 ++-- narwhals/_dask/dataframe.py | 6 +- narwhals/_dask/expr.py | 4 +- narwhals/_dask/group_by.py | 2 +- narwhals/_dask/namespace.py | 10 +- narwhals/_dask/utils.py | 2 +- narwhals/_pandas_like/dataframe.py | 2 +- narwhals/_pandas_like/utils.py | 4 +- narwhals/_polars/dataframe.py | 10 +- narwhals/_polars/namespace.py | 30 ++-- narwhals/_polars/series.py | 8 +- narwhals/_polars/utils.py | 4 +- narwhals/translate.py | 44 ++++-- utils/import_check.py | 16 ++ 18 files changed, 241 insertions(+), 205 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 139a865b5..68f7ab534 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -294,7 +294,7 @@ def select( *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr, ) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa new_series = evaluate_into_exprs(self, *exprs, **named_exprs) if not new_series: @@ -473,7 +473,7 @@ def to_dict( return {name: col.to_pylist() for name, col in names_and_values} def with_row_index(self: Self, name: str) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa df = self._native_frame @@ -500,7 +500,7 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self: return self._from_native_frame(self._native_frame.filter(mask)) def null_count(self: Self) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa df = self._native_frame names_and_values = zip(df.column_names, df.columns) @@ -572,12 +572,12 @@ def rename(self: Self, mapping: dict[str, str]) -> Self: return self._from_native_frame(df.rename_columns(new_cols)) def write_parquet(self: Self, file: Any) -> None: - import pyarrow.parquet as pp # ignore-banned-import + import pyarrow.parquet as pp pp.write_table(self._native_frame, file) def write_csv(self: Self, file: Any) -> Any: - import pyarrow as pa # ignore-banned-import + import pyarrow as pa import pyarrow.csv as pa_csv # ignore-banned-import pa_table = self._native_frame @@ -589,8 +589,8 @@ def write_csv(self: Self, file: Any) -> Any: def is_duplicated(self: Self) -> ArrowSeries: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc from narwhals._arrow.series import ArrowSeries @@ -617,7 +617,7 @@ def is_duplicated(self: Self) -> ArrowSeries: ) def is_unique(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc from narwhals._arrow.series import ArrowSeries @@ -640,8 +640,8 @@ def unique( # The param `maintain_order` is only here for compatibility with the Polars API # and has no effect on the output. import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc df = self._native_frame @@ -681,7 +681,7 @@ def sample( seed: int | None, ) -> Self: import numpy as np # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc frame = self._native_frame num_rows = len(self) @@ -701,7 +701,7 @@ def unpivot( variable_name: str | None, value_name: str | None, ) -> Self: - import pyarrow as pa # ignore-banned-import + import pyarrow as pa native_frame = self._native_frame variable_name = variable_name if variable_name is not None else "variable" diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 6bf9fa88a..f8f5c9921 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -33,7 +33,7 @@ def get_function_name_option( function_name: str, ) -> pc.CountOptions | pc.VarianceOptions | None: """Map specific pyarrow compute function to respective option to match polars behaviour.""" - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc function_name_to_options = { "count": pc.CountOptions(mode="all"), @@ -48,7 +48,7 @@ class ArrowGroupBy: def __init__( self: Self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool ) -> None: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa if drop_null_keys: self._df = df.drop_nulls(keys) @@ -87,8 +87,8 @@ def agg( ) def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) null_token = "__null_token_value__" # noqa: S105 @@ -127,7 +127,7 @@ def agg_arrow( output_names: list[str], from_dataframe: Callable[[Any], ArrowDataFrame], ) -> ArrowDataFrame: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc all_simple_aggs = True for expr in exprs: diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index aacc2a2f5..033b69da8 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -80,7 +80,7 @@ def _create_series_from_scalar( ) def _create_compliant_series(self: Self, value: Any) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa from narwhals._arrow.series import ArrowSeries @@ -254,7 +254,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) def min_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc parsed_exprs = parse_into_exprs(*exprs, namespace=self) @@ -282,7 +282,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) def max_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc parsed_exprs = parse_into_exprs(*exprs, namespace=self) @@ -385,7 +385,7 @@ def concat_str( separator: str, ignore_nulls: bool, ) -> ArrowExpr: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc parsed_exprs: list[ArrowExpr] = [ *parse_into_exprs(*exprs, namespace=self), @@ -438,8 +438,8 @@ def __init__( self._version = version def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc from narwhals._arrow.namespace import ArrowNamespace from narwhals._expression_parsing import parse_into_expr diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index f40ec4228..51c1b9943 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -62,7 +62,7 @@ def _change_version(self: Self, version: Version) -> Self: ) def _from_native_series(self: Self, series: pa.ChunkedArray | pa.Array) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa if isinstance(series, pa.Array): series = pa.chunked_array([series]) @@ -82,7 +82,7 @@ def _from_iterable( backend_version: tuple[int, ...], version: Version, ) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa return cls( pa.chunked_array([data]), @@ -108,61 +108,61 @@ def __eq__(self: Self, other: object) -> Self: # type: ignore[override] return self._from_native_series(pc.equal(ser, other)) def __ne__(self: Self, other: object) -> Self: # type: ignore[override] - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.not_equal(ser, other)) def __ge__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.greater_equal(ser, other)) def __gt__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.greater(ser, other)) def __le__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.less_equal(ser, other)) def __lt__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.less(ser, other)) def __and__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.and_kleene(ser, other)) def __rand__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.and_kleene(other, ser)) def __or__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.or_kleene(ser, other)) def __ror__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.or_kleene(other, ser)) def __add__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.add(ser, other)) @@ -171,7 +171,7 @@ def __radd__(self: Self, other: Any) -> Self: return self + other # type: ignore[no-any-return] def __sub__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.subtract(ser, other)) @@ -180,7 +180,7 @@ def __rsub__(self: Self, other: Any) -> Self: return (self - other) * (-1) # type: ignore[no-any-return] def __mul__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.multiply(ser, other)) @@ -189,13 +189,13 @@ def __rmul__(self: Self, other: Any) -> Self: return self * other # type: ignore[no-any-return] def __pow__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.power(ser, other)) def __rpow__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) return self._from_native_series(pc.power(other, ser)) @@ -209,8 +209,8 @@ def __rfloordiv__(self: Self, other: Any) -> Self: return self._from_native_series(floordiv_compat(other, ser)) def __truediv__(self: Self, other: Any) -> Self: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): @@ -219,8 +219,8 @@ def __truediv__(self: Self, other: Any) -> Self: return self._from_native_series(pc.divide(*cast_for_truediv(ser, other))) def __rtruediv__(self: Self, other: Any) -> Self: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc ser, other = broadcast_and_extract_native(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): @@ -229,7 +229,7 @@ def __rtruediv__(self: Self, other: Any) -> Self: return self._from_native_series(pc.divide(*cast_for_truediv(other, ser))) def __mod__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc floor_div = (self // other)._native_series ser, other = broadcast_and_extract_native(self, other, self._backend_version) @@ -237,7 +237,7 @@ def __mod__(self: Self, other: Any) -> Self: return self._from_native_series(res) def __rmod__(self: Self, other: Any) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc floor_div = (other // self)._native_series ser, other = broadcast_and_extract_native(self, other, self._backend_version) @@ -245,7 +245,7 @@ def __rmod__(self: Self, other: Any) -> Self: return self._from_native_series(res) def __invert__(self: Self) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series(pc.invert(self._native_series)) @@ -260,12 +260,12 @@ def filter(self: Self, other: Any) -> Self: return self._from_native_series(ser.filter(other)) def mean(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def median(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc from narwhals.exceptions import InvalidOperationError @@ -278,27 +278,27 @@ def median(self: Self, *, _return_py_scalar: bool = True) -> int: ) def min(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.min(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def max(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def sum(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.sum(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def drop_nulls(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series(pc.drop_null(self._native_series)) def shift(self: Self, n: int) -> Self: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa ca = self._native_series @@ -311,14 +311,14 @@ def shift(self: Self, n: int) -> Self: return self._from_native_series(result) def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar( # type: ignore[no-any-return] pc.stddev(self._native_series, ddof=ddof), _return_py_scalar ) def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser = self._native_series ser_not_null = pc.drop_null(ser) @@ -338,12 +338,12 @@ def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: ) def count(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc unique_values = pc.unique(self._native_series) return maybe_extract_py_scalar( # type: ignore[no-any-return] @@ -379,8 +379,8 @@ def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self: def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc mask = np.zeros(self.len(), dtype=bool) mask[indices] = True @@ -419,12 +419,12 @@ def dtype(self: Self) -> DType: return native_to_narwhals_dtype(self._native_series.type, self._version) def abs(self: Self) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series(pc.abs(self._native_series)) def cum_sum(self: Self, *, reverse: bool) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc native_series = self._native_series result = ( @@ -435,33 +435,33 @@ def cum_sum(self: Self, *, reverse: bool) -> Self: return self._from_native_series(result) def round(self: Self, decimals: int) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series( pc.round(self._native_series, decimals, round_mode="half_towards_infinity") ) def diff(self: Self) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series( pc.pairwise_diff(self._native_series.combine_chunks()) ) def any(self: Self, *, _return_py_scalar: bool = True) -> bool: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.any(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def all(self: Self, *, _return_py_scalar: bool = True) -> bool: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar(pc.all(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def is_between( self, lower_bound: Any, upper_bound: Any, closed: str = "both" ) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser = self._native_series if closed == "left": @@ -492,7 +492,7 @@ def is_null(self: Self) -> Self: return self._from_native_series(ser.is_null()) def cast(self: Self, dtype: DType) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser = self._native_series dtype = narwhals_to_native_dtype(dtype, self._version) @@ -518,8 +518,8 @@ def tail(self: Self, n: int) -> Self: return self._from_native_series(ser.slice(abs(n))) def is_in(self: Self, other: Any) -> Self: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc value_set = pa.array(other) ser = self._native_series @@ -557,8 +557,8 @@ def value_counts( normalize: bool = False, ) -> ArrowDataFrame: """Parallel is unused, exists for compatibility.""" - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc from narwhals._arrow.dataframe import ArrowDataFrame @@ -584,7 +584,7 @@ def value_counts( ) def zip_with(self: Self, mask: Self, other: Self) -> Self: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc mask = mask._native_series.combine_chunks() return self._from_native_series( @@ -604,7 +604,7 @@ def sample( seed: int | None, ) -> Self: import numpy as np # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser = self._native_series num_rows = len(self) @@ -625,8 +625,8 @@ def fill_null( limit: int | None, ) -> Self: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc def fill_aux( arr: pa.Array, @@ -671,7 +671,7 @@ def fill_aux( return res_ser def to_frame(self: Self) -> ArrowDataFrame: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa from narwhals._arrow.dataframe import ArrowDataFrame @@ -693,8 +693,8 @@ def is_unique(self: Self) -> ArrowSeries: def is_first_distinct(self: Self) -> Self: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc row_number = pa.array(np.arange(len(self))) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) @@ -710,8 +710,8 @@ def is_first_distinct(self: Self) -> Self: def is_last_distinct(self: Self) -> Self: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc row_number = pa.array(np.arange(len(self))) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) @@ -729,7 +729,7 @@ def is_sorted(self: Self, *, descending: bool) -> bool: if not isinstance(descending, bool): msg = f"argument 'descending' should be boolean, found {type(descending)}" raise TypeError(msg) - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc ser = self._native_series if descending: @@ -741,15 +741,15 @@ def is_sorted(self: Self, *, descending: bool) -> bool: def unique(self: Self, *, maintain_order: bool) -> ArrowSeries: # The param `maintain_order` is only here for compatibility with the Polars API # and has no effect on the output. - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._from_native_series(pc.unique(self._native_series)) def replace_strict( self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc # https://stackoverflow.com/a/79111029/4451315 idxs = pc.index_in(self._native_series, pa.array(old)) @@ -767,7 +767,7 @@ def replace_strict( return result def sort(self: Self, *, descending: bool, nulls_last: bool) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc series = self._native_series order = "descending" if descending else "ascending" @@ -780,7 +780,7 @@ def sort(self: Self, *, descending: bool, nulls_last: bool) -> ArrowSeries: def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFrame: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa from narwhals._arrow.dataframe import ArrowDataFrame @@ -819,7 +819,7 @@ def quantile( *, _return_py_scalar: bool = True, ) -> Any: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return maybe_extract_py_scalar( pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0], @@ -830,8 +830,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self._from_native_series(self._native_series[offset::n]) def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._native_series arr = pc.max_element_wise(arr, pa.scalar(lower_bound, type=arr.type)) @@ -850,7 +850,7 @@ def mode(self: Self) -> ArrowSeries: )[self.name] def is_finite(self: Self) -> Self: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc return self._from_native_series(pc.is_finite(self._native_series)) @@ -863,7 +863,7 @@ def cum_min(self: Self, *, reverse: bool) -> Self: msg = "cum_min method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc native_series = self._native_series @@ -879,7 +879,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc native_series = self._native_series @@ -895,7 +895,7 @@ def cum_prod(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc native_series = self._native_series @@ -913,8 +913,8 @@ def rolling_sum( min_periods: int | None, center: bool, ) -> Self: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc min_periods = min_periods if min_periods is not None else window_size if center: @@ -966,8 +966,8 @@ def rolling_mean( min_periods: int | None, center: bool, ) -> Self: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc min_periods = min_periods if min_periods is not None else window_size if center: @@ -1027,8 +1027,8 @@ def __contains__(self: Self, other: Any) -> bool: from pyarrow import ArrowTypeError # ignore-banned-imports try: - import pyarrow as pa # ignore-banned-imports - import pyarrow.compute as pc # ignore-banned-imports + import pyarrow as pa + import pyarrow.compute as pc native_series = self._native_series other_ = ( @@ -1068,7 +1068,7 @@ def __init__(self: Self, series: ArrowSeries) -> None: self._arrow_series = series def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc # PyArrow differs from other libraries in that %S also prints out # the fractional part of the second...:'( @@ -1079,7 +1079,7 @@ def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 ) def replace_time_zone(self: Self, time_zone: str | None) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc if time_zone is not None: result = pc.assume_timezone( @@ -1090,7 +1090,7 @@ def replace_time_zone(self: Self, time_zone: str | None) -> ArrowSeries: return self._arrow_series._from_native_series(result) def convert_time_zone(self: Self, time_zone: str) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import + import pyarrow as pa if self._arrow_series.dtype.time_zone is None: # type: ignore[attr-defined] result = self.replace_time_zone("UTC")._native_series.cast( @@ -1104,8 +1104,8 @@ def convert_time_zone(self: Self, time_zone: str) -> ArrowSeries: return self._arrow_series._from_native_series(result) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ArrowSeries: - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc s = self._arrow_series._native_series dtype = self._arrow_series.dtype @@ -1158,63 +1158,63 @@ def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ArrowS return self._arrow_series._from_native_series(result) def date(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa return self._arrow_series._from_native_series( self._arrow_series._native_series.cast(pa.date32()) ) def year(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.year(self._arrow_series._native_series) ) def month(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.month(self._arrow_series._native_series) ) def day(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.day(self._arrow_series._native_series) ) def hour(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.hour(self._arrow_series._native_series) ) def minute(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.minute(self._arrow_series._native_series) ) def second(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.second(self._arrow_series._native_series) ) def millisecond(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.millisecond(self._arrow_series._native_series) ) def microsecond(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc arr = self._arrow_series._native_series result = pc.add(pc.multiply(pc.millisecond(arr), 1000), pc.microsecond(arr)) @@ -1222,7 +1222,7 @@ def microsecond(self: Self) -> ArrowSeries: return self._arrow_series._from_native_series(result) def nanosecond(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc arr = self._arrow_series._native_series result = pc.add( @@ -1231,15 +1231,15 @@ def nanosecond(self: Self) -> ArrowSeries: return self._arrow_series._from_native_series(result) def ordinal_day(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.day_of_year(self._arrow_series._native_series) ) def total_minutes(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._arrow_series._native_series unit = arr.type.unit @@ -1257,8 +1257,8 @@ def total_minutes(self: Self) -> ArrowSeries: ) def total_seconds(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._arrow_series._native_series unit = arr.type.unit @@ -1276,8 +1276,8 @@ def total_seconds(self: Self) -> ArrowSeries: ) def total_milliseconds(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._arrow_series._native_series unit = arr.type.unit @@ -1301,8 +1301,8 @@ def total_milliseconds(self: Self) -> ArrowSeries: ) def total_microseconds(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._arrow_series._native_series unit = arr.type.unit @@ -1325,8 +1325,8 @@ def total_microseconds(self: Self) -> ArrowSeries: ) def total_nanoseconds(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow as pa + import pyarrow.compute as pc arr = self._arrow_series._native_series unit = arr.type.unit @@ -1350,7 +1350,7 @@ def __init__(self: Self, series: ArrowSeries) -> None: self._arrow_series = series def get_categories(self: Self) -> ArrowSeries: - import pyarrow as pa # ignore-banned-import() + import pyarrow as pa ca = self._arrow_series._native_series # TODO(Unassigned): this looks potentially expensive - is there no better way? @@ -1366,7 +1366,7 @@ def __init__(self: Self, series: ArrowSeries) -> None: self._arrow_series = series def len_chars(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.utf8_length(self._arrow_series._native_series) @@ -1375,7 +1375,7 @@ def len_chars(self: Self) -> ArrowSeries: def replace( self: Self, pattern: str, value: str, *, literal: bool, n: int ) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc method = "replace_substring" if literal else "replace_substring_regex" return self._arrow_series._from_native_series( @@ -1393,7 +1393,7 @@ def replace_all( return self.replace(pattern, value, literal=literal, n=-1) def strip_chars(self: Self, characters: str | None) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc whitespace = " \t\n\r\v\f" return self._arrow_series._from_native_series( @@ -1404,21 +1404,21 @@ def strip_chars(self: Self, characters: str | None) -> ArrowSeries: ) def starts_with(self: Self, prefix: str) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.equal(self.slice(0, len(prefix))._native_series, prefix) ) def ends_with(self: Self, suffix: str) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.equal(self.slice(-len(suffix), None)._native_series, suffix) ) def contains(self: Self, pattern: str, *, literal: bool) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc check_func = pc.match_substring if literal else pc.match_substring_regex return self._arrow_series._from_native_series( @@ -1426,7 +1426,7 @@ def contains(self: Self, pattern: str, *, literal: bool) -> ArrowSeries: ) def slice(self: Self, offset: int, length: int | None) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc stop = offset + length if length is not None else None return self._arrow_series._from_native_series( @@ -1436,7 +1436,7 @@ def slice(self: Self, offset: int, length: int | None) -> ArrowSeries: ) def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002 - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc if format is None: format = parse_datetime_format(self._arrow_series._native_series) @@ -1446,14 +1446,14 @@ def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002 ) def to_uppercase(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.utf8_upper(self._arrow_series._native_series), ) def to_lowercase(self: Self) -> ArrowSeries: - import pyarrow.compute as pc # ignore-banned-import() + import pyarrow.compute as pc return self._arrow_series._from_native_series( pc.utf8_lower(self._arrow_series._native_series), diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index d79de38f6..80cdb1a8a 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -18,7 +18,7 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: - import pyarrow as pa # ignore-banned-import + import pyarrow as pa dtypes = import_dtypes_module(version) if pa.types.is_int64(dtype): @@ -80,7 +80,7 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa.DataType: - import pyarrow as pa # ignore-banned-import + import pyarrow as pa dtypes = import_dtypes_module(version) if isinstance_or_issubclass(dtype, dtypes.Float64): @@ -170,7 +170,7 @@ def broadcast_and_extract_native( if len(lhs) == 1: # broadcast import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import + import pyarrow as pa fill_value = lhs[0] if backend_version < (13,) and hasattr(fill_value, "as_py"): @@ -201,7 +201,7 @@ def validate_dataframe_comparand( if isinstance(other, ArrowSeries): if len(other) == 1: import numpy as np # ignore-banned-import - import pyarrow as pa # ignore-banned-import + import pyarrow as pa value = other._native_series[0] if backend_version < (13,) and hasattr(value, "as_py"): @@ -223,7 +223,7 @@ def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: Should be in namespace. """ - import pyarrow as pa # ignore-banned-import + import pyarrow as pa names = [name for df in dfs for name in df.column_names] @@ -251,7 +251,7 @@ def vertical_concat(dfs: list[pa.Table]) -> pa.Table: ) raise TypeError(msg) - import pyarrow as pa # ignore-banned-import + import pyarrow as pa return pa.concat_tables(dfs).combine_chunks() @@ -261,7 +261,7 @@ def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa Should be in namespace. """ - import pyarrow as pa # ignore-banned-import + import pyarrow as pa kwargs = ( {"promote": True} @@ -274,8 +274,8 @@ def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc if isinstance(left, (int, float)): left = pa.scalar(left) @@ -315,8 +315,8 @@ def cast_for_truediv( ) -> tuple[pa.ChunkedArray | pa.Scalar, pa.ChunkedArray | pa.Scalar]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int @@ -338,7 +338,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: if fast_path: return [s._native_series for s in series] - import pyarrow as pa # ignore-banned-import + import pyarrow as pa is_max_length_gt_1 = max_length > 1 reshaped = [] @@ -427,8 +427,8 @@ def convert_str_slice_to_int_slice( def parse_datetime_format(arr: pa.StringArray) -> str: """Try to infer datetime format from StringArray.""" - import pyarrow as pa # ignore-banned-import - import pyarrow.compute as pc # ignore-banned-import + import pyarrow as pa + import pyarrow.compute as pc matches = pa.concat_arrays( # converts from ChunkedArray to StructArray pc.extract_regex(pc.drop_null(arr).slice(0, 10), pattern=FULL_RE).chunks @@ -465,7 +465,7 @@ def parse_datetime_format(arr: pa.StringArray) -> str: def _parse_date_format(arr: pa.Array) -> str: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc for date_rgx, date_fmt in DATE_FORMATS: matches = pc.extract_regex(arr, pattern=date_rgx) @@ -487,7 +487,7 @@ def _parse_date_format(arr: pa.Array) -> str: def _parse_time_format(arr: pa.Array) -> str: - import pyarrow.compute as pc # ignore-banned-import + import pyarrow.compute as pc for time_rgx, time_fmt in TIME_FORMATS: matches = pc.extract_regex(arr, pattern=time_rgx) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 98f7745ad..4184ae409 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -76,7 +76,7 @@ def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: return self._from_native_frame(df) def collect(self) -> Any: - import pandas as pd # ignore-banned-import() + import pandas as pd from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -119,7 +119,7 @@ def select( *exprs: IntoDaskExpr, **named_exprs: IntoDaskExpr, ) -> Self: - import dask.dataframe as dd # ignore-banned-import + import dask.dataframe as dd if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs: # This is a simple slice => fastpath! @@ -136,7 +136,7 @@ def select( if not new_series: # return empty dataframe, like Polars does - import pandas as pd # ignore-banned-import + import pandas as pd return self._from_native_frame( dd.from_pandas(pd.DataFrame(), npartitions=self._native_frame.npartitions) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 01d312949..76b9cd431 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -856,7 +856,7 @@ def func(_input: Any, dtype: DType | type[DType]) -> Any: ) def is_finite(self: Self) -> Self: - import dask.array as da # ignore-banned-import + import dask.array as da return self._from_call( lambda _input: da.isfinite(_input), @@ -1008,7 +1008,7 @@ def slice(self, offset: int, length: int | None = None) -> DaskExpr: ) def to_datetime(self: Self, format: str | None) -> DaskExpr: # noqa: A002 - import dask.dataframe as dd # ignore-banned-import() + import dask.dataframe as dd return self._expr._from_call( lambda _input, fmt: dd.to_datetime(_input, format=fmt), diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index e7ff8b77a..3bc21efb9 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -19,7 +19,7 @@ def n_unique() -> dd.Aggregation: - import dask.dataframe as dd # ignore-banned-import + import dask.dataframe as dd def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int: return s.nunique(dropna=False) # type: ignore[no-any-return] diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 3478a0a4d..b3e2814ca 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -115,8 +115,8 @@ def sum(self, *column_names: str) -> DaskExpr: ).sum() def len(self) -> DaskExpr: - import dask.dataframe as dd # ignore-banned-import - import pandas as pd # ignore-banned-import + import dask.dataframe as dd + import pandas as pd def func(df: DaskLazyFrame) -> list[dask_expr.Series]: if not df.columns: @@ -200,7 +200,7 @@ def concat( *, how: Literal["horizontal", "vertical", "diagonal"], ) -> DaskLazyFrame: - import dask.dataframe as dd # ignore-banned-import + import dask.dataframe as dd if len(list(items)) == 0: msg = "No items to concatenate" # pragma: no cover @@ -276,7 +276,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: ) def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: - import dask.dataframe as dd # ignore-banned-import + import dask.dataframe as dd parsed_exprs = parse_into_exprs(*exprs, namespace=self) @@ -297,7 +297,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: ) def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: - import dask.dataframe as dd # ignore-banned-import + import dask.dataframe as dd parsed_exprs = parse_into_exprs(*exprs, namespace=self) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index cdcef2a0d..86a0e5193 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -68,7 +68,7 @@ def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame: def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None: - import dask_expr # ignore-banned-import + import dask_expr if not dask_expr._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover # are_co_aligned is a method which cheaply checks if two Dask expressions diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index d9c47a184..a897548bf 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -740,7 +740,7 @@ def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any: # returns Object) then we just call `to_numpy()` on the DataFrame. for col_dtype in df.dtypes: if str(col_dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING: - import numpy as np # ignore-banned-import + import numpy as np return np.hstack( [self[col].to_numpy(copy=copy)[:, None] for col in self.columns] diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 4224e2925..7ea360836 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -392,7 +392,7 @@ def native_to_narwhals_dtype( if implementation is Implementation.PANDAS: # pragma: no cover # This is the most efficient implementation for pandas, # and doesn't require the interchange protocol - import pandas as pd # ignore-banned-import + import pandas as pd dtype = pd.api.types.infer_dtype(native_column, skipna=True) if dtype == "string": @@ -416,7 +416,7 @@ def native_to_narwhals_dtype( def get_dtype_backend(dtype: Any, implementation: Implementation) -> str: if implementation is Implementation.PANDAS: - import pandas as pd # ignore-banned-import() + import pandas as pd if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype): return "pyarrow-nullable" diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index bf3e1cc29..643675d14 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -86,7 +86,7 @@ def _from_native_object(self: Self, obj: T) -> T: ... def _from_native_object( self: Self, obj: pl.Series | pl.DataFrame | T ) -> Self | PolarsSeries | T: - import polars as pl # ignore-banned-import() + import polars as pl if isinstance(obj, pl.Series): from narwhals._polars.series import PolarsSeries @@ -101,7 +101,7 @@ def _from_native_object( def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: - import polars as pl # ignore-banned-import() + import polars as pl args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] try: @@ -184,7 +184,7 @@ def __getitem__(self: Self, item: Any) -> Any: ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover raise TypeError(msg) # pragma: no cover - import polars as pl # ignore-banned-import() + import polars as pl if ( isinstance(item, tuple) @@ -376,7 +376,7 @@ def _change_version(self: Self, version: Version) -> Self: def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: - import polars as pl # ignore-banned-import + import polars as pl args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] try: @@ -425,7 +425,7 @@ def collect_schema(self: Self) -> dict[str, DType]: } def collect(self: Self) -> PolarsDataFrame: - import polars as pl # ignore-banned-import + import polars as pl try: result = self._native_frame.collect() diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 1459a13bc..2363e2895 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -32,7 +32,7 @@ def __init__( self._version = version def __getattr__(self: Self, attr: str) -> Any: - import polars as pl # ignore-banned-import + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -47,7 +47,7 @@ def func(*args: Any, **kwargs: Any) -> Any: return func def nth(self: Self, *indices: int) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -59,7 +59,7 @@ def nth(self: Self, *indices: int) -> PolarsExpr: ) def len(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -95,7 +95,7 @@ def concat( *, how: Literal["vertical", "horizontal", "diagonal"], ) -> PolarsDataFrame | PolarsLazyFrame: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame @@ -113,7 +113,7 @@ def concat( ) def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -128,7 +128,7 @@ def lit(self: Self, value: Any, dtype: DType | None = None) -> PolarsExpr: ) def mean(self: Self, *column_names: str) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -145,7 +145,7 @@ def mean(self: Self, *column_names: str) -> PolarsExpr: ) def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -166,7 +166,7 @@ def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: ) def median(self: Self, *column_names: str) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -183,7 +183,7 @@ def concat_str( separator: str, ignore_nulls: bool, ) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -246,7 +246,7 @@ def __init__(self: Self, version: Version, backend_version: tuple[int, ...]) -> self._backend_version = backend_version def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -259,7 +259,7 @@ def by_dtype(self: Self, dtypes: Iterable[DType]) -> PolarsExpr: ) def numeric(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -270,7 +270,7 @@ def numeric(self: Self) -> PolarsExpr: ) def boolean(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -281,7 +281,7 @@ def boolean(self: Self) -> PolarsExpr: ) def string(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -292,7 +292,7 @@ def string(self: Self) -> PolarsExpr: ) def categorical(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr @@ -303,7 +303,7 @@ def categorical(self: Self) -> PolarsExpr: ) def all(self: Self) -> PolarsExpr: - import polars as pl # ignore-banned-import() + import polars as pl from narwhals._polars.expr import PolarsExpr diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 0bfea20ad..c960e8441 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -74,7 +74,7 @@ def _from_native_object(self: Self, series: T) -> T: ... def _from_native_object( self: Self, series: pl.Series | pl.DataFrame | T ) -> Self | PolarsDataFrame | T: - import polars as pl # ignore-banned-import() + import polars as pl if isinstance(series, pl.Series): return self._from_native_series(series) @@ -230,7 +230,7 @@ def median(self: Self) -> Any: return self._native_series.median() def to_dummies(self: Self, *, separator: str, drop_first: bool) -> PolarsDataFrame: - import polars as pl # ignore-banned-import + import polars as pl from narwhals._polars.dataframe import PolarsDataFrame @@ -283,7 +283,7 @@ def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: result = self._native_series.sort(descending=descending) if nulls_last: - import polars as pl # ignore-banned-import() + import polars as pl is_null = result.is_null() result = pl.concat([result.filter(~is_null), result.filter(is_null)]) @@ -311,7 +311,7 @@ def value_counts( from narwhals._polars.dataframe import PolarsDataFrame if self._backend_version < (1, 0, 0): - import polars as pl # ignore-banned-import() + import polars as pl value_name_ = name or ("proportion" if normalize else "count") diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 35f51b02d..97406b644 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -69,7 +69,7 @@ def native_to_narwhals_dtype( version: Version, backend_version: tuple[int, ...], ) -> DType: - import polars as pl # ignore-banned-import() + import polars as pl dtypes = import_dtypes_module(version) if dtype == pl.Float64: @@ -140,7 +140,7 @@ def native_to_narwhals_dtype( def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pl.DataType: - import polars as pl # ignore-banned-import() + import polars as pl dtypes = import_dtypes_module(version) diff --git a/narwhals/translate.py b/narwhals/translate.py index e673d7e35..7c2c6d36b 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -403,18 +403,6 @@ def _from_native_impl( # noqa: PLR0915 allow_series: bool | None = None, version: Version, ) -> Any: - from narwhals._arrow.dataframe import ArrowDataFrame - from narwhals._arrow.series import ArrowSeries - from narwhals._dask.dataframe import DaskLazyFrame - from narwhals._duckdb.dataframe import DuckDBInterchangeFrame - from narwhals._ibis.dataframe import IbisInterchangeFrame - from narwhals._interchange.dataframe import InterchangeFrame - from narwhals._pandas_like.dataframe import PandasLikeDataFrame - from narwhals._pandas_like.series import PandasLikeSeries - from narwhals._polars.dataframe import PolarsDataFrame - from narwhals._polars.dataframe import PolarsLazyFrame - from narwhals._polars.series import PolarsSeries - from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series @@ -475,6 +463,8 @@ def _from_native_impl( # noqa: PLR0915 # Polars elif is_polars_dataframe(native_object): + from narwhals._polars.dataframe import PolarsDataFrame + if series_only: if not pass_through: msg = "Cannot only use `series_only` with polars.DataFrame" @@ -490,6 +480,8 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) elif is_polars_lazyframe(native_object): + from narwhals._polars.dataframe import PolarsLazyFrame + if series_only: if not pass_through: msg = "Cannot only use `series_only` with polars.LazyFrame" @@ -510,6 +502,8 @@ def _from_native_impl( # noqa: PLR0915 level="lazy", ) elif is_polars_series(native_object): + from narwhals._polars.series import PolarsSeries + pl = get_polars() if not allow_series: if not pass_through: @@ -527,6 +521,8 @@ def _from_native_impl( # noqa: PLR0915 # pandas elif is_pandas_dataframe(native_object): + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + if series_only: if not pass_through: msg = "Cannot only use `series_only` with dataframe" @@ -543,6 +539,8 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) elif is_pandas_series(native_object): + from narwhals._pandas_like.series import PandasLikeSeries + if not allow_series: if not pass_through: msg = "Please set `allow_series=True` or `series_only=True`" @@ -561,6 +559,8 @@ def _from_native_impl( # noqa: PLR0915 # Modin elif is_modin_dataframe(native_object): # pragma: no cover + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + mpd = get_modin() if series_only: if not pass_through: @@ -577,6 +577,8 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) elif is_modin_series(native_object): # pragma: no cover + from narwhals._pandas_like.series import PandasLikeSeries + mpd = get_modin() if not allow_series: if not pass_through: @@ -595,6 +597,8 @@ def _from_native_impl( # noqa: PLR0915 # cuDF elif is_cudf_dataframe(native_object): # pragma: no cover + from narwhals._pandas_like.dataframe import PandasLikeDataFrame + cudf = get_cudf() if series_only: if not pass_through: @@ -611,6 +615,8 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) elif is_cudf_series(native_object): # pragma: no cover + from narwhals._pandas_like.series import PandasLikeSeries + cudf = get_cudf() if not allow_series: if not pass_through: @@ -629,6 +635,8 @@ def _from_native_impl( # noqa: PLR0915 # PyArrow elif is_pyarrow_table(native_object): + from narwhals._arrow.dataframe import ArrowDataFrame + pa = get_pyarrow() if series_only: if not pass_through: @@ -644,6 +652,8 @@ def _from_native_impl( # noqa: PLR0915 level="full", ) elif is_pyarrow_chunked_array(native_object): + from narwhals._arrow.series import ArrowSeries + pa = get_pyarrow() if not allow_series: if not pass_through: @@ -662,6 +672,8 @@ def _from_native_impl( # noqa: PLR0915 # Dask elif is_dask_dataframe(native_object): + from narwhals._dask.dataframe import DaskLazyFrame + if series_only: if not pass_through: msg = "Cannot only use `series_only` with dask DataFrame" @@ -686,6 +698,8 @@ def _from_native_impl( # noqa: PLR0915 # DuckDB elif is_duckdb_relation(native_object): + from narwhals._duckdb.dataframe import DuckDBInterchangeFrame + if eager_only or series_only: # pragma: no cover if not pass_through: msg = ( @@ -702,6 +716,8 @@ def _from_native_impl( # noqa: PLR0915 # Ibis elif is_ibis_table(native_object): # pragma: no cover + from narwhals._ibis.dataframe import IbisInterchangeFrame + if eager_only or series_only: if not pass_through: msg = ( @@ -717,6 +733,8 @@ def _from_native_impl( # noqa: PLR0915 # PySpark elif is_pyspark_dataframe(native_object): # pragma: no cover + from narwhals._spark_like.dataframe import SparkLikeLazyFrame + if series_only: msg = "Cannot only use `series_only` with pyspark DataFrame" raise TypeError(msg) @@ -734,6 +752,8 @@ def _from_native_impl( # noqa: PLR0915 # Interchange protocol elif hasattr(native_object, "__dataframe__"): + from narwhals._interchange.dataframe import InterchangeFrame + if eager_only or series_only: if not pass_through: msg = ( diff --git a/utils/import_check.py b/utils/import_check.py index e8d776cde..eee35dfc4 100644 --- a/utils/import_check.py +++ b/utils/import_check.py @@ -15,6 +15,14 @@ "pandas", "polars", "pyarrow", + "pyspark", +} + +ALLOWED_IMPORTS = { + "_pandas_like": {"pandas", "numpy"}, + "_arrow": {"pyarrow", "pyarrow.compute", "pyarrow.parquet"}, + "_dask": {"dask.dataframe", "pandas", "dask_expr"}, + "_polars": {"polars"}, } @@ -23,6 +31,12 @@ def __init__(self, file_name: str, lines: list[str]) -> None: self.file_name = file_name self.lines = lines self.found_import = False + for key, val in ALLOWED_IMPORTS.items(): + if key in self.file_name: + self.allowed_imports: set[str] = val + break + else: + self.allowed_imports = set() def visit_If(self, node: ast.If) -> None: # noqa: N802 # Check if the condition is `if TYPE_CHECKING` @@ -35,12 +49,14 @@ def visit_Import(self, node: ast.Import) -> None: # noqa: N802 for alias in node.names: if ( alias.name in BANNED_IMPORTS + and alias.name not in self.allowed_imports and "# ignore-banned-import" not in self.lines[node.lineno - 1] ): print( # noqa: T201 f"{self.file_name}:{node.lineno}:{node.col_offset}: found {alias.name} import" ) self.found_import = True + self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802