diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index c0c145d26c63..ede73ea5f87c 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -364,6 +364,10 @@ impl DataType { return Some(true); } + if self.is_null() { + return Some(true); + } + use DataType as D; Some(match (self, to) { #[cfg(feature = "dtype-categorical")] diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 8a7fe78bea48..ba4fedcdbb8e 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -247,6 +247,7 @@ string_pad = ["polars-plan/string_pad"] string_reverse = ["polars-plan/string_reverse"] string_to_integer = ["polars-plan/string_to_integer"] arg_where = ["polars-plan/arg_where"] +index_of = ["polars-plan/index_of"] search_sorted = ["polars-plan/search_sorted"] merge_sorted = ["polars-plan/merge_sorted", "polars-stream?/merge_sorted"] meta = ["polars-plan/meta"] @@ -314,6 +315,7 @@ test_all = [ "row_hash", "string_pad", "string_to_integer", + "index_of", "search_sorted", "top_k", "pivot", @@ -360,6 +362,7 @@ features = [ "fused", "futures", "hist", + "index_of", "interpolate", "interpolate_by", "ipc", diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index c3049bc9403d..7c0586ac0e6c 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -114,6 +114,7 @@ rolling_window = ["polars-core/rolling_window"] rolling_window_by = ["polars-core/rolling_window_by"] moment = [] mode = [] +index_of = [] search_sorted = [] merge_sorted = [] top_k = [] diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs new file mode 100644 index 000000000000..6c2536e263ec --- /dev/null +++ b/crates/polars-ops/src/series/ops/index_of.rs @@ -0,0 +1,121 @@ +use arrow::array::{BinaryArray, PrimitiveArray}; +use polars_core::downcast_as_macro_arg_physical; +use polars_core::prelude::*; +use polars_utils::total_ord::TotalEq; +use row_encode::encode_rows_unordered; + +/// Find the index of the value, or ``None`` if it can't be found. +fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray
, value: AR::ValueT<'a>) -> Option +where + DT: PolarsDataType, + AR: StaticArray, + AR::ValueT<'a>: TotalEq, +{ + let req_value = &value; + let mut index = 0; + for chunk in ca.chunks() { + let chunk = chunk.as_any().downcast_ref::().unwrap(); + if chunk.validity().is_some() { + for maybe_value in chunk.iter() { + if maybe_value.map(|v| v.tot_eq(req_value)) == Some(true) { + return Some(index); + } else { + index += 1; + } + } + } else { + // A lack of a validity bitmap means there are no nulls, so we + // can simplify our logic and use a faster code path: + for value in chunk.values_iter() { + if value.tot_eq(req_value) { + return Some(index); + } else { + index += 1; + } + } + } + } + None +} + +fn index_of_numeric_value(ca: &ChunkedArray, value: T::Native) -> Option +where + T: PolarsNumericType, +{ + index_of_value::<_, PrimitiveArray>(ca, value) +} + +/// Try casting the value to the correct type, then call +/// index_of_numeric_value(). +macro_rules! try_index_of_numeric_ca { + ($ca:expr, $value:expr) => {{ + let ca = $ca; + let value = $value; + // extract() returns None if casting failed, so consider an extract() + // failure as not finding the value. Nulls should have been handled + // earlier. + let value = value.value().extract().unwrap(); + index_of_numeric_value(ca, value) + }}; +} + +/// Find the index of a given value (the first and only entry in `value_series`) +/// within the series. +pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> { + polars_ensure!( + series.dtype() == needle.dtype(), + InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}", + series.dtype(), + needle.dtype(), + ); + + // Series is null: + if series.dtype().is_null() { + if needle.is_null() { + return Ok((series.len() > 0).then_some(0)); + } else { + return Ok(None); + } + } + + // Series is not null, and the value is null: + if needle.is_null() { + let mut index = 0; + for chunk in series.chunks() { + let length = chunk.len(); + if let Some(bitmap) = chunk.validity() { + let leading_ones = bitmap.leading_ones(); + if leading_ones < length { + return Ok(Some(index + leading_ones)); + } + } else { + index += length; + } + } + return Ok(None); + } + + if series.dtype().is_primitive_numeric() { + return Ok(downcast_as_macro_arg_physical!( + series, + try_index_of_numeric_ca, + needle + )); + } + + if series.dtype().is_categorical() { + // See https://github.com/pola-rs/polars/issues/20318 + polars_bail!(InvalidOperation: "index_of() on Categoricals is not supported"); + } + + // For non-numeric dtypes, we convert to row-encoding, which essentially has + // us searching the physical representation of the data as a series of + // bytes. + let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1); + let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?; + let value = value_as_row_encoded_ca + .first() + .expect("Shouldn't have nulls in a row-encoded result"); + let ca = encode_rows_unordered(&[series.clone().into()])?; + Ok(index_of_value::<_, BinaryArray>(&ca, value)) +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 6be3b85cb8e2..40254b97f92a 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -21,6 +21,8 @@ mod floor_divide; mod fused; mod horizontal; mod index; +#[cfg(feature = "index_of")] +mod index_of; mod int_range; #[cfg(any(feature = "interpolate_by", feature = "interpolate"))] mod interpolation; @@ -84,6 +86,8 @@ pub use floor_divide::*; pub use fused::*; pub use horizontal::*; pub use index::*; +#[cfg(feature = "index_of")] +pub use index_of::*; pub use int_range::*; #[cfg(feature = "interpolate")] pub use interpolation::interpolate::*; diff --git a/crates/polars-ops/src/series/ops/search_sorted.rs b/crates/polars-ops/src/series/ops/search_sorted.rs index b95a33d3bc42..510cdde01340 100644 --- a/crates/polars-ops/src/series/ops/search_sorted.rs +++ b/crates/polars-ops/src/series/ops/search_sorted.rs @@ -11,6 +11,7 @@ pub fn search_sorted( let original_dtype = s.dtype(); if s.dtype().is_categorical() { + // See https://github.com/pola-rs/polars/issues/20171 polars_bail!(InvalidOperation: "'search_sorted' is not supported on dtype: {}", s.dtype()) } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index f6547ed249a7..0bb251c57a95 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -159,6 +159,7 @@ string_pad = ["polars-ops/string_pad"] string_reverse = ["polars-ops/string_reverse"] string_to_integer = ["polars-ops/string_to_integer"] arg_where = [] +index_of = ["polars-ops/index_of"] search_sorted = ["polars-ops/search_sorted"] merge_sorted = ["polars-ops/merge_sorted"] meta = [] @@ -263,6 +264,7 @@ features = [ "find_many", "string_encoding", "ipc", + "index_of", "search_sorted", "unique_counts", "dtype-u8", diff --git a/crates/polars-plan/src/dsl/function_expr/index_of.rs b/crates/polars-plan/src/dsl/function_expr/index_of.rs new file mode 100644 index 000000000000..d396d7065091 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/index_of.rs @@ -0,0 +1,61 @@ +use polars_ops::series::index_of as index_of_op; + +use super::*; + +/// Given two columns, find the index of a value (the second column) within the +/// first column. Will use binary search if possible, as an optimization. +pub(super) fn index_of(s: &mut [Column]) -> PolarsResult { + let series = if let Column::Scalar(ref sc) = s[0] { + // We only care about the first value: + &sc.as_single_value_series() + } else { + s[0].as_materialized_series() + }; + + let needle_s = &s[1]; + polars_ensure!( + needle_s.len() == 1, + InvalidOperation: "needle of `index_of` can only contain a single value, found {} values", + needle_s.len() + ); + let needle = Scalar::new( + needle_s.dtype().clone(), + needle_s.get(0).unwrap().into_static(), + ); + + let is_sorted_flag = series.is_sorted_flag(); + let result = match is_sorted_flag { + // If the Series is sorted, we can use an optimized binary search to + // find the value. + IsSorted::Ascending | IsSorted::Descending + if !needle.is_null() && + // search_sorted() doesn't support decimals at the moment. + !series.dtype().is_decimal() => + { + search_sorted( + series, + needle_s.as_materialized_series(), + SearchSortedSide::Left, + IsSorted::Descending == is_sorted_flag, + )? + .get(0) + .and_then(|idx| { + // search_sorted() gives an index even if it's not an exact + // match! So we want to make sure it actually found the value. + if series.get(idx as usize).ok()? == needle.as_any_value() { + Some(idx as usize) + } else { + None + } + }) + }, + _ => index_of_op(series, needle)?, + }; + + let av = match result { + None => AnyValue::Null, + Some(idx) => AnyValue::from(idx as IdxSize), + }; + let scalar = Scalar::new(IDX_DTYPE, av); + Ok(Column::new_scalar(series.name().clone(), scalar, 1)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index e2eb108bd541..b55a6b0bcf1f 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -34,6 +34,8 @@ mod ewm_by; mod fill_null; #[cfg(feature = "fused")] mod fused; +#[cfg(feature = "index_of")] +mod index_of; mod list; #[cfg(feature = "log")] mod log; @@ -154,6 +156,8 @@ pub enum FunctionExpr { Hash(u64, u64, u64, u64), #[cfg(feature = "arg_where")] ArgWhere, + #[cfg(feature = "index_of")] + IndexOf, #[cfg(feature = "search_sorted")] SearchSorted(SearchSortedSide), #[cfg(feature = "range")] @@ -395,6 +399,8 @@ impl Hash for FunctionExpr { #[cfg(feature = "business")] Business(f) => f.hash(state), Pow(f) => f.hash(state), + #[cfg(feature = "index_of")] + IndexOf => {}, #[cfg(feature = "search_sorted")] SearchSorted(f) => f.hash(state), #[cfg(feature = "random")] @@ -640,6 +646,8 @@ impl Display for FunctionExpr { Hash(_, _, _, _) => "hash", #[cfg(feature = "arg_where")] ArgWhere => "arg_where", + #[cfg(feature = "index_of")] + IndexOf => "index_of", #[cfg(feature = "search_sorted")] SearchSorted(_) => "search_sorted", #[cfg(feature = "range")] @@ -929,6 +937,10 @@ impl From for SpecialEq> { ArgWhere => { wrap!(arg_where::arg_where) }, + #[cfg(feature = "index_of")] + IndexOf => { + map_as_slice!(index_of::index_of) + }, #[cfg(feature = "search_sorted")] SearchSorted(side) => { map_as_slice!(search_sorted::search_sorted_impl, side) diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 1c9265fad841..92e8397f3d04 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -49,6 +49,8 @@ impl FunctionExpr { Hash(..) => mapper.with_dtype(DataType::UInt64), #[cfg(feature = "arg_where")] ArgWhere => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "index_of")] + IndexOf => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "search_sorted")] SearchSorted(_) => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "range")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index f831cac12c6f..f85e9bd551da 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -377,6 +377,22 @@ impl Expr { ) } + #[cfg(feature = "index_of")] + /// Find the index of a value. + pub fn index_of>(self, element: E) -> Expr { + let element = element.into(); + Expr::Function { + input: vec![self, element], + function: FunctionExpr::IndexOf, + options: FunctionOptions { + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + fmt_str: "index_of", + cast_options: Some(CastingRules::FirstArgLossless), + ..Default::default() + }, + } + } + #[cfg(feature = "search_sorted")] /// Find indices where elements should be inserted to maintain order. pub fn search_sorted>(self, element: E, side: SearchSortedSide) -> Expr { diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index 13a2164ed560..b0ab044862e3 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -134,6 +134,7 @@ repeat_by = ["polars/repeat_by"] streaming = ["polars/streaming"] meta = ["polars/meta"] +index_of = ["polars/index_of"] search_sorted = ["polars/search_sorted"] decompress = ["polars/decompress-fast"] regex = ["polars/regex"] @@ -211,6 +212,7 @@ operations = [ "asof_join", "cross_join", "pct_change", + "index_of", "search_sorted", "merge_sorted", "top_k", diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index fe5fdafdbbb8..d51bd9c9f808 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -318,6 +318,11 @@ impl PyExpr { self.inner.clone().arg_min().into() } + #[cfg(feature = "index_of")] + fn index_of(&self, element: Self) -> Self { + self.inner.clone().index_of(element.inner).into() + } + #[cfg(feature = "search_sorted")] fn search_sorted(&self, element: Self, side: Wrap) -> Self { self.inner @@ -325,6 +330,7 @@ impl PyExpr { .search_sorted(element.inner, side.0) .into() } + fn gather(&self, idx: Self) -> Self { self.inner.clone().gather(idx.inner).into() } diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 99ac9a405509..3553ae83bf89 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1099,6 +1099,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { ("hash", seed, seed_1, seed_2, seed_3).into_py_any(py) }, FunctionExpr::ArgWhere => ("argwhere",).into_py_any(py), + #[cfg(feature = "index_of")] + FunctionExpr::IndexOf => ("index_of",).into_py_any(py), #[cfg(feature = "search_sorted")] FunctionExpr::SearchSorted(side) => ( "search_sorted", diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 25a1bcbb0489..bb7fe3c06c9c 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -212,6 +212,7 @@ rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window"] rolling_window_by = ["polars-core/rolling_window_by", "polars-lazy?/rolling_window_by", "polars-time/rolling_window_by"] round_series = ["polars-ops/round_series", "polars-lazy?/round_series"] row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] +index_of = ["polars-lazy?/index_of"] search_sorted = ["polars-lazy?/search_sorted"] semi_anti_join = ["polars-lazy?/semi_anti_join", "polars-ops/semi_anti_join", "polars-sql?/semi_anti_join"] sign = ["polars-lazy?/sign"] diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst index 4b488ec0d692..76b67746a4fa 100644 --- a/py-polars/docs/source/reference/expressions/computation.rst +++ b/py-polars/docs/source/reference/expressions/computation.rst @@ -42,6 +42,7 @@ Computation Expr.exp Expr.hash Expr.hist + Expr.index_of Expr.kurtosis Expr.log Expr.log10 diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 8cdb8fe152fa..871fa768eb82 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -46,6 +46,7 @@ Computation Series.first Series.hash Series.hist + Series.index_of Series.is_between Series.kurtosis Series.last diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index d0e94dd2f454..e6d17b207898 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2311,6 +2311,37 @@ def arg_min(self) -> Expr: """ return self._from_pyexpr(self._pyexpr.arg_min()) + def index_of(self, element: IntoExpr) -> Expr: + """ + Get the index of the first occurrence of a value, or ``None`` if it's not found. + + Parameters + ---------- + element + Value to find. + + Examples + -------- + >>> df = pl.DataFrame({"a": [1, None, 17]}) + >>> df.select( + ... [ + ... pl.col("a").index_of(17).alias("seventeen"), + ... pl.col("a").index_of(None).alias("null"), + ... pl.col("a").index_of(55).alias("fiftyfive"), + ... ] + ... ) + shape: (1, 3) + ┌───────────┬──────┬───────────┐ + │ seventeen ┆ null ┆ fiftyfive │ + │ --- ┆ --- ┆ --- │ + │ u32 ┆ u32 ┆ u32 │ + ╞═══════════╪══════╪═══════════╡ + │ 2 ┆ 1 ┆ null │ + └───────────┴──────┴───────────┘ + """ + element = parse_into_expression(element, str_as_lit=True, list_as_series=False) + return self._from_pyexpr(self._pyexpr.index_of(element)) + def search_sorted( self, element: IntoExpr | np.ndarray[Any, Any], side: SearchSortedSide = "any" ) -> Expr: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 7df33cda0208..03aaabc85c57 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4769,6 +4769,27 @@ def scatter( self._s.scatter(indices._s, values._s) return self + def index_of(self, element: IntoExpr) -> int | None: + """ + Get the index of the first occurrence of a value, or ``None`` if it's not found. + + Parameters + ---------- + element + Value to find. + + Examples + -------- + >>> s = pl.Series("a", [1, None, 17]) + >>> s.index_of(17) + 2 + >>> s.index_of(None) # search for a null + 1 + >>> s.index_of(55) is None + True + """ + return F.select(F.lit(self).index_of(element)).item() + def clear(self, n: int = 0) -> Series: """ Create an empty copy of the current Series, with zero to 'n' elements. diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py new file mode 100644 index 000000000000..bb72e8afdfe7 --- /dev/null +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest +from hypothesis import example, given +from hypothesis import strategies as st + +import polars as pl +from polars.exceptions import InvalidOperationError + +if TYPE_CHECKING: + from polars._typing import IntoExpr +from polars.testing import assert_frame_equal + + +def assert_index_of( + series: pl.Series, + value: IntoExpr, + convert_to_literal: bool = False, +) -> None: + """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" + if isinstance(value, (np.number, float, int)) and np.isnan(value): + expected_index = None + for i, o in enumerate(series.to_list()): + if o is not None and np.isnan(o): + expected_index = i + break + else: + try: + expected_index = series.to_list().index(value) + except ValueError: + expected_index = None + if expected_index == -1: + expected_index = None + + if convert_to_literal: + value = pl.lit(value, dtype=series.dtype) + + # Eager API: + assert series.index_of(value) == expected_index + # Lazy API: + assert pl.LazyFrame({"series": series}).select( + pl.col("series").index_of(value) + ).collect().get_column("series").to_list() == [expected_index] + + +@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) +def test_float(dtype: pl.DataType) -> None: + values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan] + series = pl.Series(values, dtype=dtype) + sorted_series_asc = series.sort(descending=False) + sorted_series_desc = series.sort(descending=True) + chunked_series = pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False) + + extra_values = [ + np.int8(3), + np.int64(2**42), + np.float64(1.5), + np.float32(1.5), + np.float32(2**37), + np.float64(2**100), + ] + for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: + for value in values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + for value in extra_values: # type: ignore[assignment] + assert_index_of(s, value) + + # Explicitly check some extra-tricky edge cases: + assert series.index_of(-np.nan) == 1 # -np.nan should match np.nan + assert series.index_of(-0.0) == 6 # -0.0 should match 0.0 + + +def test_null() -> None: + series = pl.Series([None, None], dtype=pl.Null) + assert_index_of(series, None) + + +def test_empty() -> None: + series = pl.Series([], dtype=pl.Null) + assert_index_of(series, None) + series = pl.Series([], dtype=pl.Int64) + assert_index_of(series, None) + assert_index_of(series, 12) + assert_index_of(series.sort(descending=True), 12) + assert_index_of(series.sort(descending=False), 12) + + +@pytest.mark.parametrize( + "dtype", + [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64], +) +def test_integer(dtype: pl.DataType) -> None: + values = [51, 3, None, 4] + series = pl.Series(values, dtype=dtype) + sorted_series_asc = series.sort(descending=False) + sorted_series_desc = series.sort(descending=True) + chunked_series = pl.concat( + [pl.Series([100, 7], dtype=dtype), series], rechunk=False + ) + + extra_values = [pl.select(v).item() for v in [dtype.max(), dtype.min()]] # type: ignore[attr-defined] + for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: + value: IntoExpr + for value in values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + for value in extra_values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + + # Can't cast floats: + for f in [np.float32(3.1), np.float64(3.1), 50.9]: + with pytest.raises(InvalidOperationError, match="cannot cast lossless"): + s.index_of(f) # type: ignore[arg-type] + + +def test_groupby() -> None: + df = pl.DataFrame( + {"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]} + ) + expected = pl.DataFrame( + {"label": ["a", "b"], "value": [1, 2]}, + schema={"label": pl.String, "value": pl.UInt32}, + ) + assert_frame_equal( + df.group_by("label", maintain_order=True).agg(pl.col("value").index_of(20)), + expected, + ) + assert_frame_equal( + df.lazy() + .group_by("label", maintain_order=True) + .agg(pl.col("value").index_of(20)) + .collect(), + expected, + ) + + +LISTS_STRATEGY = st.lists( + st.one_of(st.none(), st.integers(min_value=10, max_value=50)), max_size=10 +) + + +@given( + list1=LISTS_STRATEGY, + list2=LISTS_STRATEGY, + list3=LISTS_STRATEGY, +) +# The examples are cases where this test previously caught bugs: +@example([], [], [None]) +def test_randomized( + list1: list[int | None], list2: list[int | None], list3: list[int | None] +) -> None: + series = pl.concat( + [pl.Series(values, dtype=pl.Int8) for values in [list1, list2, list3]], + rechunk=False, + ) + sorted_series = series.sort(descending=False) + sorted_series2 = series.sort(descending=True) + + # Values are between 10 and 50, plus add None and max/min range values: + for i in set(range(10, 51)) | {-128, 127, None}: + assert_index_of(series, i) + assert_index_of(sorted_series, i) + assert_index_of(sorted_series2, i) + + +ENUM = pl.Enum(["a", "b", "c"]) + + +@pytest.mark.parametrize( + ("series", "extra_values", "sortable"), + [ + (pl.Series(["abc", None, "bb"]), ["", "🚲"], True), + (pl.Series([True, None, False, True, False]), [], True), + ( + pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]), + [datetime(2023, 12, 12, 16, 12, 39)], + True, + ), + ( + pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]), + [date(2023, 12, 12)], + True, + ), + ( + pl.Series([time(16, 12, 31), None, time(11, 10, 53)]), + [time(11, 12, 16)], + True, + ), + ( + pl.Series( + [timedelta(hours=12), None, timedelta(minutes=3)], + ), + [timedelta(minutes=17)], + True, + ), + (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True), + ( + pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]), + [[[5, 7]], []], + True, + ), + ( + pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)), + [[5, 7], [None, None]], + True, + ), + ( + pl.Series( + [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]], + dtype=pl.Array(pl.Array(pl.Int64(), 2), 1), + ), + [[[5, 7]], [[None, None]]], + True, + ), + ( + pl.Series( + [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}], + dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}), + ), + [{"a": 7, "b": None}, {"a": 6, "b": 4}], + False, + ), + (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True), + (pl.Series([Decimal(12), None, Decimal(3)]), [Decimal(4)], True), + ], +) +def test_other_types( + series: pl.Series, extra_values: list[Any], sortable: bool +) -> None: + expected_values = series.to_list() + series_variants = [series, series.drop_nulls()] + if sortable: + series_variants.extend( + [ + series.sort(descending=False), + series.sort(descending=True), + ] + ) + for s in series_variants: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=True) + assert_index_of(s, value, convert_to_literal=False) + # Extra values may not be expressible as literal of correct dtype, so + # don't try: + for value in extra_values: + assert_index_of(s, value) + + +# Before the output type would be list[idx-type] when no item was found +def test_non_found_correct_type() -> None: + df = pl.DataFrame( + [ + pl.Series("a", [0, 1], pl.Int32), + pl.Series("b", [1, 2], pl.Int32), + ] + ) + + assert_frame_equal( + df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1)), + pl.DataFrame({"a": [0, 1], "b": [0, None]}), + check_dtypes=False, + ) + + +def test_error_on_multiple_values() -> None: + with pytest.raises( + pl.exceptions.InvalidOperationError, + match="needle of `index_of` can only contain", + ): + pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3])) + + +@pytest.mark.parametrize( + "convert_to_literal", + [ + True, + False, + ], +) +def test_enum(convert_to_literal: bool) -> None: + series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"])) + expected_values = series.to_list() + for s in [ + series, + series.drop_nulls(), + series.sort(descending=False), + series.sort(descending=True), + ]: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=convert_to_literal) + + +@pytest.mark.parametrize( + "convert_to_literal", + [ + pytest.param( + True, + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/20318" + ), + ), + pytest.param( + False, + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/20171" + ), + ), + ], +) +def test_categorical(convert_to_literal: bool) -> None: + series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical) + expected_values = series.to_list() + for s in [ + series, + series.drop_nulls(), + series.sort(descending=False), + series.sort(descending=True), + ]: + for value in expected_values: + assert_index_of(s, value, convert_to_literal=convert_to_literal)