diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs
index c0c145d26c63..ede73ea5f87c 100644
--- a/crates/polars-core/src/datatypes/dtype.rs
+++ b/crates/polars-core/src/datatypes/dtype.rs
@@ -364,6 +364,10 @@ impl DataType {
return Some(true);
}
+ if self.is_null() {
+ return Some(true);
+ }
+
use DataType as D;
Some(match (self, to) {
#[cfg(feature = "dtype-categorical")]
diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml
index 8a7fe78bea48..ba4fedcdbb8e 100644
--- a/crates/polars-lazy/Cargo.toml
+++ b/crates/polars-lazy/Cargo.toml
@@ -247,6 +247,7 @@ string_pad = ["polars-plan/string_pad"]
string_reverse = ["polars-plan/string_reverse"]
string_to_integer = ["polars-plan/string_to_integer"]
arg_where = ["polars-plan/arg_where"]
+index_of = ["polars-plan/index_of"]
search_sorted = ["polars-plan/search_sorted"]
merge_sorted = ["polars-plan/merge_sorted", "polars-stream?/merge_sorted"]
meta = ["polars-plan/meta"]
@@ -314,6 +315,7 @@ test_all = [
"row_hash",
"string_pad",
"string_to_integer",
+ "index_of",
"search_sorted",
"top_k",
"pivot",
@@ -360,6 +362,7 @@ features = [
"fused",
"futures",
"hist",
+ "index_of",
"interpolate",
"interpolate_by",
"ipc",
diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml
index c3049bc9403d..7c0586ac0e6c 100644
--- a/crates/polars-ops/Cargo.toml
+++ b/crates/polars-ops/Cargo.toml
@@ -114,6 +114,7 @@ rolling_window = ["polars-core/rolling_window"]
rolling_window_by = ["polars-core/rolling_window_by"]
moment = []
mode = []
+index_of = []
search_sorted = []
merge_sorted = []
top_k = []
diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs
new file mode 100644
index 000000000000..6c2536e263ec
--- /dev/null
+++ b/crates/polars-ops/src/series/ops/index_of.rs
@@ -0,0 +1,121 @@
+use arrow::array::{BinaryArray, PrimitiveArray};
+use polars_core::downcast_as_macro_arg_physical;
+use polars_core::prelude::*;
+use polars_utils::total_ord::TotalEq;
+use row_encode::encode_rows_unordered;
+
+/// Find the index of the first element equal to `value`, or ``None`` if it
+/// can't be found.
+///
+/// Every chunk of `ca` is downcast to the concrete array type `AR` and its
+/// elements are compared to `value` using [`TotalEq`]. Null entries never
+/// match.
+///
+/// # Panics
+///
+/// Panics if a chunk of `ca` cannot be downcast to `AR`.
+fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize>
+where
+    DT: PolarsDataType,
+    AR: StaticArray,
+    AR::ValueT<'a>: TotalEq,
+{
+    let req_value = &value;
+    // Position of the element currently being inspected, counted across all
+    // chunks.
+    let mut index = 0;
+    for chunk in ca.chunks() {
+        let chunk = chunk.as_any().downcast_ref::<AR>().unwrap();
+        if chunk.validity().is_some() {
+            for maybe_value in chunk.iter() {
+                if maybe_value.is_some_and(|v| v.tot_eq(req_value)) {
+                    return Some(index);
+                }
+                index += 1;
+            }
+        } else {
+            // A lack of a validity bitmap means there are no nulls, so we
+            // can simplify our logic and use a faster code path:
+            for value in chunk.values_iter() {
+                if value.tot_eq(req_value) {
+                    return Some(index);
+                }
+                index += 1;
+            }
+        }
+    }
+    None
+}
+
+/// Numeric specialization of `index_of_value()`: searches the chunks of `ca`
+/// as `PrimitiveArray`s of the chunked array's native type.
+fn index_of_numeric_value<T>(ca: &ChunkedArray<T>, value: T::Native) -> Option<usize>
+where
+    T: PolarsNumericType,
+{
+    index_of_value::<_, PrimitiveArray<T::Native>>(ca, value)
+}
+
+/// Try casting the value to the correct type, then call
+/// index_of_numeric_value().
+///
+/// Evaluates to the `Option` index of the value; a value that cannot be
+/// extracted as the array's native type yields `None`.
+macro_rules! try_index_of_numeric_ca {
+    ($ca:expr, $value:expr) => {{
+        let ca = $ca;
+        let value = $value;
+        // extract() returns None if casting failed, so consider an extract()
+        // failure as not finding the value instead of panicking. Nulls should
+        // have been handled earlier.
+        value
+            .value()
+            .extract()
+            .and_then(|value| index_of_numeric_value(ca, value))
+    }};
+}
+
+/// Find the index of the first occurrence of `needle` within `series`.
+///
+/// Returns `Ok(None)` when the value does not occur. A null `needle` looks for
+/// the first null entry of the series.
+///
+/// # Errors
+///
+/// Returns an `InvalidOperation` error if the dtypes of `series` and `needle`
+/// differ, or if `series` is categorical. Errors from row-encoding are
+/// propagated.
+pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> {
+    polars_ensure!(
+        series.dtype() == needle.dtype(),
+        InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}",
+        series.dtype(),
+        needle.dtype(),
+    );
+
+    // Series is null:
+    if series.dtype().is_null() {
+        if needle.is_null() {
+            return Ok((series.len() > 0).then_some(0));
+        } else {
+            return Ok(None);
+        }
+    }
+
+    // Series is not null, and the value is null:
+    if needle.is_null() {
+        let mut index = 0;
+        for chunk in series.chunks() {
+            let length = chunk.len();
+            if let Some(bitmap) = chunk.validity() {
+                let leading_ones = bitmap.leading_ones();
+                if leading_ones < length {
+                    return Ok(Some(index + leading_ones));
+                }
+            }
+            // No null was found in this chunk, whether or not it carries a
+            // validity bitmap, so skip past all of its elements. Advancing
+            // only for bitmap-less chunks would under-count the offset.
+            index += length;
+        }
+        return Ok(None);
+    }
+
+    if series.dtype().is_primitive_numeric() {
+        return Ok(downcast_as_macro_arg_physical!(
+            series,
+            try_index_of_numeric_ca,
+            needle
+        ));
+    }
+
+    if series.dtype().is_categorical() {
+        // See https://github.com/pola-rs/polars/issues/20318
+        polars_bail!(InvalidOperation: "index_of() on Categoricals is not supported");
+    }
+
+    // For non-numeric dtypes, we convert to row-encoding, which essentially has
+    // us searching the physical representation of the data as a series of
+    // bytes.
+    let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1);
+    let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?;
+    let value = value_as_row_encoded_ca
+        .first()
+        .expect("Shouldn't have nulls in a row-encoded result");
+    let ca = encode_rows_unordered(&[series.clone().into()])?;
+    Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value))
+}
diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs
index 6be3b85cb8e2..40254b97f92a 100644
--- a/crates/polars-ops/src/series/ops/mod.rs
+++ b/crates/polars-ops/src/series/ops/mod.rs
@@ -21,6 +21,8 @@ mod floor_divide;
mod fused;
mod horizontal;
mod index;
+#[cfg(feature = "index_of")]
+mod index_of;
mod int_range;
#[cfg(any(feature = "interpolate_by", feature = "interpolate"))]
mod interpolation;
@@ -84,6 +86,8 @@ pub use floor_divide::*;
pub use fused::*;
pub use horizontal::*;
pub use index::*;
+#[cfg(feature = "index_of")]
+pub use index_of::*;
pub use int_range::*;
#[cfg(feature = "interpolate")]
pub use interpolation::interpolate::*;
diff --git a/crates/polars-ops/src/series/ops/search_sorted.rs b/crates/polars-ops/src/series/ops/search_sorted.rs
index b95a33d3bc42..510cdde01340 100644
--- a/crates/polars-ops/src/series/ops/search_sorted.rs
+++ b/crates/polars-ops/src/series/ops/search_sorted.rs
@@ -11,6 +11,7 @@ pub fn search_sorted(
let original_dtype = s.dtype();
if s.dtype().is_categorical() {
+ // See https://github.com/pola-rs/polars/issues/20171
polars_bail!(InvalidOperation: "'search_sorted' is not supported on dtype: {}", s.dtype())
}
diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml
index f6547ed249a7..0bb251c57a95 100644
--- a/crates/polars-plan/Cargo.toml
+++ b/crates/polars-plan/Cargo.toml
@@ -159,6 +159,7 @@ string_pad = ["polars-ops/string_pad"]
string_reverse = ["polars-ops/string_reverse"]
string_to_integer = ["polars-ops/string_to_integer"]
arg_where = []
+index_of = ["polars-ops/index_of"]
search_sorted = ["polars-ops/search_sorted"]
merge_sorted = ["polars-ops/merge_sorted"]
meta = []
@@ -263,6 +264,7 @@ features = [
"find_many",
"string_encoding",
"ipc",
+ "index_of",
"search_sorted",
"unique_counts",
"dtype-u8",
diff --git a/crates/polars-plan/src/dsl/function_expr/index_of.rs b/crates/polars-plan/src/dsl/function_expr/index_of.rs
new file mode 100644
index 000000000000..d396d7065091
--- /dev/null
+++ b/crates/polars-plan/src/dsl/function_expr/index_of.rs
@@ -0,0 +1,61 @@
+use polars_ops::series::index_of as index_of_op;
+
+use super::*;
+
+/// Given two columns, find the index of a value (the single entry of the
+/// second column) within the first column. Will use binary search, as an
+/// optimization, when the first column is flagged as sorted.
+///
+/// The result is a single-value `Column` of `IDX_DTYPE`, named after the
+/// first column, holding the index or null if the value was not found.
+///
+/// # Errors
+///
+/// Returns an `InvalidOperation` error if the second column does not contain
+/// exactly one value. Errors from the underlying search are propagated.
+pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Column> {
+    let series = if let Column::Scalar(ref sc) = s[0] {
+        // We only care about the first value:
+        &sc.as_single_value_series()
+    } else {
+        s[0].as_materialized_series()
+    };
+
+    let needle_s = &s[1];
+    polars_ensure!(
+        needle_s.len() == 1,
+        InvalidOperation: "needle of `index_of` can only contain a single value, found {} values",
+        needle_s.len()
+    );
+    let needle = Scalar::new(
+        needle_s.dtype().clone(),
+        needle_s.get(0).unwrap().into_static(),
+    );
+
+    let is_sorted_flag = series.is_sorted_flag();
+    let result = match is_sorted_flag {
+        // If the Series is sorted, we can use an optimized binary search to
+        // find the value. Null needles always take the linear path below.
+        IsSorted::Ascending | IsSorted::Descending
+            if !needle.is_null() &&
+            // search_sorted() doesn't support decimals at the moment.
+            !series.dtype().is_decimal() =>
+        {
+            search_sorted(
+                series,
+                needle_s.as_materialized_series(),
+                SearchSortedSide::Left,
+                IsSorted::Descending == is_sorted_flag,
+            )?
+            .get(0)
+            .and_then(|idx| {
+                // search_sorted() gives an index even if it's not an exact
+                // match! So we want to make sure it actually found the value.
+                if series.get(idx as usize).ok()? == needle.as_any_value() {
+                    Some(idx as usize)
+                } else {
+                    None
+                }
+            })
+        },
+        _ => index_of_op(series, needle)?,
+    };
+
+    // A missing value is reported as a null of the index dtype.
+    let av = result.map_or(AnyValue::Null, |idx| AnyValue::from(idx as IdxSize));
+    let scalar = Scalar::new(IDX_DTYPE, av);
+    Ok(Column::new_scalar(series.name().clone(), scalar, 1))
+}
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index e2eb108bd541..b55a6b0bcf1f 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -34,6 +34,8 @@ mod ewm_by;
mod fill_null;
#[cfg(feature = "fused")]
mod fused;
+#[cfg(feature = "index_of")]
+mod index_of;
mod list;
#[cfg(feature = "log")]
mod log;
@@ -154,6 +156,8 @@ pub enum FunctionExpr {
Hash(u64, u64, u64, u64),
#[cfg(feature = "arg_where")]
ArgWhere,
+ #[cfg(feature = "index_of")]
+ IndexOf,
#[cfg(feature = "search_sorted")]
SearchSorted(SearchSortedSide),
#[cfg(feature = "range")]
@@ -395,6 +399,8 @@ impl Hash for FunctionExpr {
#[cfg(feature = "business")]
Business(f) => f.hash(state),
Pow(f) => f.hash(state),
+ #[cfg(feature = "index_of")]
+ IndexOf => {},
#[cfg(feature = "search_sorted")]
SearchSorted(f) => f.hash(state),
#[cfg(feature = "random")]
@@ -640,6 +646,8 @@ impl Display for FunctionExpr {
Hash(_, _, _, _) => "hash",
#[cfg(feature = "arg_where")]
ArgWhere => "arg_where",
+ #[cfg(feature = "index_of")]
+ IndexOf => "index_of",
#[cfg(feature = "search_sorted")]
SearchSorted(_) => "search_sorted",
#[cfg(feature = "range")]
@@ -929,6 +937,10 @@ impl From for SpecialEq> {
ArgWhere => {
wrap!(arg_where::arg_where)
},
+ #[cfg(feature = "index_of")]
+ IndexOf => {
+ map_as_slice!(index_of::index_of)
+ },
#[cfg(feature = "search_sorted")]
SearchSorted(side) => {
map_as_slice!(search_sorted::search_sorted_impl, side)
diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs
index 1c9265fad841..92e8397f3d04 100644
--- a/crates/polars-plan/src/dsl/function_expr/schema.rs
+++ b/crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -49,6 +49,8 @@ impl FunctionExpr {
Hash(..) => mapper.with_dtype(DataType::UInt64),
#[cfg(feature = "arg_where")]
ArgWhere => mapper.with_dtype(IDX_DTYPE),
+ #[cfg(feature = "index_of")]
+ IndexOf => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "search_sorted")]
SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "range")]
diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs
index f831cac12c6f..f85e9bd551da 100644
--- a/crates/polars-plan/src/dsl/mod.rs
+++ b/crates/polars-plan/src/dsl/mod.rs
@@ -377,6 +377,22 @@ impl Expr {
)
}
+ #[cfg(feature = "index_of")]
+    /// Find the index of the first occurrence of a value.
+    ///
+    /// `element` is passed as the second input of the function expression; the
+    /// expression is flagged as returning a scalar and uses the
+    /// `FirstArgLossless` casting rule for its inputs.
+    pub fn index_of<E: Into<Expr>>(self, element: E) -> Expr {
+        Expr::Function {
+            input: vec![self, element.into()],
+            function: FunctionExpr::IndexOf,
+            options: FunctionOptions {
+                flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
+                fmt_str: "index_of",
+                cast_options: Some(CastingRules::FirstArgLossless),
+                ..Default::default()
+            },
+        }
+    }
+
#[cfg(feature = "search_sorted")]
/// Find indices where elements should be inserted to maintain order.
pub fn search_sorted>(self, element: E, side: SearchSortedSide) -> Expr {
diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml
index 13a2164ed560..b0ab044862e3 100644
--- a/crates/polars-python/Cargo.toml
+++ b/crates/polars-python/Cargo.toml
@@ -134,6 +134,7 @@ repeat_by = ["polars/repeat_by"]
streaming = ["polars/streaming"]
meta = ["polars/meta"]
+index_of = ["polars/index_of"]
search_sorted = ["polars/search_sorted"]
decompress = ["polars/decompress-fast"]
regex = ["polars/regex"]
@@ -211,6 +212,7 @@ operations = [
"asof_join",
"cross_join",
"pct_change",
+ "index_of",
"search_sorted",
"merge_sorted",
"top_k",
diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs
index fe5fdafdbbb8..d51bd9c9f808 100644
--- a/crates/polars-python/src/expr/general.rs
+++ b/crates/polars-python/src/expr/general.rs
@@ -318,6 +318,11 @@ impl PyExpr {
self.inner.clone().arg_min().into()
}
+ #[cfg(feature = "index_of")]
+    fn index_of(&self, element: Self) -> Self {
+        // The wrapped expression is cloned because `Expr::index_of` takes its
+        // receiver by value.
+        let haystack = self.inner.clone();
+        haystack.index_of(element.inner).into()
+    }
+
#[cfg(feature = "search_sorted")]
fn search_sorted(&self, element: Self, side: Wrap) -> Self {
self.inner
@@ -325,6 +330,7 @@ impl PyExpr {
.search_sorted(element.inner, side.0)
.into()
}
+
fn gather(&self, idx: Self) -> Self {
self.inner.clone().gather(idx.inner).into()
}
diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
index 99ac9a405509..3553ae83bf89 100644
--- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
+++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
@@ -1099,6 +1099,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult {
("hash", seed, seed_1, seed_2, seed_3).into_py_any(py)
},
FunctionExpr::ArgWhere => ("argwhere",).into_py_any(py),
+ #[cfg(feature = "index_of")]
+ FunctionExpr::IndexOf => ("index_of",).into_py_any(py),
#[cfg(feature = "search_sorted")]
FunctionExpr::SearchSorted(side) => (
"search_sorted",
diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml
index 25a1bcbb0489..bb7fe3c06c9c 100644
--- a/crates/polars/Cargo.toml
+++ b/crates/polars/Cargo.toml
@@ -212,6 +212,7 @@ rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window"]
rolling_window_by = ["polars-core/rolling_window_by", "polars-lazy?/rolling_window_by", "polars-time/rolling_window_by"]
round_series = ["polars-ops/round_series", "polars-lazy?/round_series"]
row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"]
+index_of = ["polars-lazy?/index_of"]
search_sorted = ["polars-lazy?/search_sorted"]
semi_anti_join = ["polars-lazy?/semi_anti_join", "polars-ops/semi_anti_join", "polars-sql?/semi_anti_join"]
sign = ["polars-lazy?/sign"]
diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst
index 4b488ec0d692..76b67746a4fa 100644
--- a/py-polars/docs/source/reference/expressions/computation.rst
+++ b/py-polars/docs/source/reference/expressions/computation.rst
@@ -42,6 +42,7 @@ Computation
Expr.exp
Expr.hash
Expr.hist
+ Expr.index_of
Expr.kurtosis
Expr.log
Expr.log10
diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst
index 8cdb8fe152fa..871fa768eb82 100644
--- a/py-polars/docs/source/reference/series/computation.rst
+++ b/py-polars/docs/source/reference/series/computation.rst
@@ -46,6 +46,7 @@ Computation
Series.first
Series.hash
Series.hist
+ Series.index_of
Series.is_between
Series.kurtosis
Series.last
diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py
index d0e94dd2f454..e6d17b207898 100644
--- a/py-polars/polars/expr/expr.py
+++ b/py-polars/polars/expr/expr.py
@@ -2311,6 +2311,37 @@ def arg_min(self) -> Expr:
"""
return self._from_pyexpr(self._pyexpr.arg_min())
+ def index_of(self, element: IntoExpr) -> Expr:
+        """
+        Get the index of the first occurrence of a value, or ``None`` if it's not found.
+
+        Parameters
+        ----------
+        element
+            Value to find.
+
+        Examples
+        --------
+        >>> df = pl.DataFrame({"a": [1, None, 17]})
+        >>> df.select(
+        ...     [
+        ...         pl.col("a").index_of(17).alias("seventeen"),
+        ...         pl.col("a").index_of(None).alias("null"),
+        ...         pl.col("a").index_of(55).alias("fiftyfive"),
+        ...     ]
+        ... )
+        shape: (1, 3)
+        ┌───────────┬──────┬───────────┐
+        │ seventeen ┆ null ┆ fiftyfive │
+        │ ---       ┆ ---  ┆ ---       │
+        │ u32       ┆ u32  ┆ u32       │
+        ╞═══════════╪══════╪═══════════╡
+        │ 2         ┆ 1    ┆ null      │
+        └───────────┴──────┴───────────┘
+        """
+        # Plain strings are treated as literal values rather than column names.
+        element_pyexpr = parse_into_expression(
+            element, str_as_lit=True, list_as_series=False
+        )
+        return self._from_pyexpr(self._pyexpr.index_of(element_pyexpr))
+
def search_sorted(
self, element: IntoExpr | np.ndarray[Any, Any], side: SearchSortedSide = "any"
) -> Expr:
diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py
index 7df33cda0208..03aaabc85c57 100644
--- a/py-polars/polars/series/series.py
+++ b/py-polars/polars/series/series.py
@@ -4769,6 +4769,27 @@ def scatter(
self._s.scatter(indices._s, values._s)
return self
+ def index_of(self, element: IntoExpr) -> int | None:
+        """
+        Get the index of the first occurrence of a value, or ``None`` if it's not found.
+
+        Parameters
+        ----------
+        element
+            Value to find.
+
+        Examples
+        --------
+        >>> s = pl.Series("a", [1, None, 17])
+        >>> s.index_of(17)
+        2
+        >>> s.index_of(None)  # search for a null
+        1
+        >>> s.index_of(55) is None
+        True
+        """
+        # Delegate to the expression implementation on a literal of this Series.
+        lookup = F.lit(self).index_of(element)
+        return F.select(lookup).item()
+
def clear(self, n: int = 0) -> Series:
"""
Create an empty copy of the current Series, with zero to 'n' elements.
diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py
new file mode 100644
index 000000000000..bb72e8afdfe7
--- /dev/null
+++ b/py-polars/tests/unit/operations/test_index_of.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import pytest
+from hypothesis import example, given
+from hypothesis import strategies as st
+
+import polars as pl
+from polars.exceptions import InvalidOperationError
+
+if TYPE_CHECKING:
+ from polars._typing import IntoExpr
+from polars.testing import assert_frame_equal
+
+
+def assert_index_of(
+    series: pl.Series,
+    value: IntoExpr,
+    convert_to_literal: bool = False,
+) -> None:
+    """
+    Check ``index_of()`` against a pure-Python reference, eagerly and lazily.
+
+    The expected index is computed from ``series.to_list()``; it is ``None``
+    if the value can't be found. With ``convert_to_literal`` the value is
+    first wrapped in ``pl.lit()`` using the dtype of the series.
+    """
+    if isinstance(value, (np.number, float, int)) and np.isnan(value):
+        # NaN never compares equal to itself, so ``list.index()`` can't be
+        # used to locate it; scan for the first NaN explicitly.
+        expected_index = None
+        for i, o in enumerate(series.to_list()):
+            if o is not None and np.isnan(o):
+                expected_index = i
+                break
+    else:
+        # ``list.index()`` signals a missing value by raising ``ValueError``.
+        try:
+            expected_index = series.to_list().index(value)
+        except ValueError:
+            expected_index = None
+
+    if convert_to_literal:
+        value = pl.lit(value, dtype=series.dtype)
+
+    # Eager API:
+    assert series.index_of(value) == expected_index
+    # Lazy API:
+    assert pl.LazyFrame({"series": series}).select(
+        pl.col("series").index_of(value)
+    ).collect().get_column("series").to_list() == [expected_index]
+
+
+@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
+def test_float(dtype: pl.DataType) -> None:
+    """Search float series, including NaN, infinities and signed zeros."""
+    values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]
+    series = pl.Series(values, dtype=dtype)
+    variants = [
+        series,
+        series.sort(descending=False),
+        series.sort(descending=True),
+        # Multi-chunk variant:
+        pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False),
+    ]
+
+    # Needles given as NumPy scalars of various numeric types:
+    extra_values = [
+        np.int8(3),
+        np.int64(2**42),
+        np.float64(1.5),
+        np.float32(1.5),
+        np.float32(2**37),
+        np.float64(2**100),
+    ]
+    for s in variants:
+        for value in values:
+            for as_literal in (True, False):
+                assert_index_of(s, value, convert_to_literal=as_literal)
+        for value in extra_values:  # type: ignore[assignment]
+            assert_index_of(s, value)
+
+    # Explicitly check some extra-tricky edge cases:
+    assert series.index_of(-np.nan) == 1  # -np.nan should match np.nan
+    assert series.index_of(-0.0) == 6  # -0.0 should match 0.0
+
+
+def test_null() -> None:
+    """Search for a null in a series of dtype ``Null``."""
+    all_nulls = pl.Series([None, None], dtype=pl.Null)
+    assert_index_of(all_nulls, None)
+
+
+def test_empty() -> None:
+    """Search empty series, both of dtype ``Null`` and of dtype ``Int64``."""
+    empty_nulls = pl.Series([], dtype=pl.Null)
+    assert_index_of(empty_nulls, None)
+
+    empty_ints = pl.Series([], dtype=pl.Int64)
+    assert_index_of(empty_ints, None)
+    int_variants = [
+        empty_ints,
+        empty_ints.sort(descending=True),
+        empty_ints.sort(descending=False),
+    ]
+    for s in int_variants:
+        assert_index_of(s, 12)
+
+
+# Run index_of() against every signed and unsigned integer dtype.
+@pytest.mark.parametrize(
+ "dtype",
+ [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64],
+)
+def test_integer(dtype: pl.DataType) -> None:
+ # The same values are searched in the plain series, in both sorted
+ # orders, and in a series made of two chunks (rechunk=False).
+ values = [51, 3, None, 4]
+ series = pl.Series(values, dtype=dtype)
+ sorted_series_asc = series.sort(descending=False)
+ sorted_series_desc = series.sort(descending=True)
+ chunked_series = pl.concat(
+ [pl.Series([100, 7], dtype=dtype), series], rechunk=False
+ )
+
+ # The dtype's own maximum and minimum are used as additional needles.
+ extra_values = [pl.select(v).item() for v in [dtype.max(), dtype.min()]] # type: ignore[attr-defined]
+ for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:
+ value: IntoExpr
+ for value in values:
+ assert_index_of(s, value, convert_to_literal=True)
+ assert_index_of(s, value, convert_to_literal=False)
+ for value in extra_values:
+ assert_index_of(s, value, convert_to_literal=True)
+ assert_index_of(s, value, convert_to_literal=False)
+
+ # Can't cast floats:
+ # (non-integral float needles must raise instead of being truncated)
+ for f in [np.float32(3.1), np.float64(3.1), 50.9]:
+ with pytest.raises(InvalidOperationError, match="cannot cast lossless"):
+ s.index_of(f) # type: ignore[arg-type]
+
+
+def test_groupby() -> None:
+    """``index_of()`` in a ``group_by().agg()`` context gives one index per group."""
+    df = pl.DataFrame(
+        {"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]}
+    )
+    expected = pl.DataFrame(
+        {"label": ["a", "b"], "value": [1, 2]},
+        schema={"label": pl.String, "value": pl.UInt32},
+    )
+    aggregation = pl.col("value").index_of(20)
+
+    # Eager API:
+    eager_result = df.group_by("label", maintain_order=True).agg(aggregation)
+    assert_frame_equal(eager_result, expected)
+
+    # Lazy API:
+    lazy_result = (
+        df.lazy().group_by("label", maintain_order=True).agg(aggregation).collect()
+    )
+    assert_frame_equal(lazy_result, expected)
+
+
+# Hypothesis strategy: lists of at most 10 items, each either None or an
+# integer between 10 and 50 (inclusive).
+LISTS_STRATEGY = st.lists(
+ st.one_of(st.none(), st.integers(min_value=10, max_value=50)), max_size=10
+)
+
+
+@given(
+    list1=LISTS_STRATEGY,
+    list2=LISTS_STRATEGY,
+    list3=LISTS_STRATEGY,
+)
+# The examples are cases where this test previously caught bugs:
+@example([], [], [None])
+def test_randomized(
+    list1: list[int | None], list2: list[int | None], list3: list[int | None]
+) -> None:
+    """Property test over a three-chunk ``Int8`` series and its sorted variants."""
+    chunks = [pl.Series(values, dtype=pl.Int8) for values in (list1, list2, list3)]
+    series = pl.concat(chunks, rechunk=False)
+    variants = [
+        series,
+        series.sort(descending=False),
+        series.sort(descending=True),
+    ]
+
+    # Values are between 10 and 50, plus add None and max/min range values:
+    needles = set(range(10, 51)) | {-128, 127, None}
+    for needle in needles:
+        for s in variants:
+            assert_index_of(s, needle)
+
+
+# NOTE(review): ENUM is not referenced anywhere else in this module
+# (test_enum builds its own pl.Enum) - confirm whether it can be dropped.
+ENUM = pl.Enum(["a", "b", "c"])
+
+
+# Each case is (series, extra_values, sortable):
+# - series: the haystack; every one of its own values is used as a needle.
+# - extra_values: additional needles, searched without literal conversion.
+# - sortable: whether sorted variants of the series are searched as well.
+@pytest.mark.parametrize(
+ ("series", "extra_values", "sortable"),
+ [
+ (pl.Series(["abc", None, "bb"]), ["", "🚲"], True),
+ (pl.Series([True, None, False, True, False]), [], True),
+ (
+ pl.Series([datetime(1997, 12, 31), datetime(1996, 1, 1)]),
+ [datetime(2023, 12, 12, 16, 12, 39)],
+ True,
+ ),
+ (
+ pl.Series([date(1997, 12, 31), None, date(1996, 1, 1)]),
+ [date(2023, 12, 12)],
+ True,
+ ),
+ (
+ pl.Series([time(16, 12, 31), None, time(11, 10, 53)]),
+ [time(11, 12, 16)],
+ True,
+ ),
+ (
+ pl.Series(
+ [timedelta(hours=12), None, timedelta(minutes=3)],
+ ),
+ [timedelta(minutes=17)],
+ True,
+ ),
+ (pl.Series([[1, 2], None, [4, 5], [6], [None, 3, 5]]), [[5, 7], []], True),
+ (
+ pl.Series([[[1, 2]], None, [[4, 5]], [[6]], [[None, 3, 5]], [None]]),
+ [[[5, 7]], []],
+ True,
+ ),
+ (
+ pl.Series([[1, 2], None, [4, 5], [None, 3]], dtype=pl.Array(pl.Int64(), 2)),
+ [[5, 7], [None, None]],
+ True,
+ ),
+ (
+ pl.Series(
+ [[[1, 2]], [None], [[4, 5]], None, [[None, 3]]],
+ dtype=pl.Array(pl.Array(pl.Int64(), 2), 1),
+ ),
+ [[[5, 7]], [[None, None]]],
+ True,
+ ),
+ # The struct case is the only one flagged as not sortable.
+ (
+ pl.Series(
+ [{"a": 1, "b": 2}, None, {"a": 3, "b": 4}, {"a": None, "b": 2}],
+ dtype=pl.Struct({"a": pl.Int64(), "b": pl.Int64()}),
+ ),
+ [{"a": 7, "b": None}, {"a": 6, "b": 4}],
+ False,
+ ),
+ (pl.Series([b"abc", None, b"xxx"]), [b"\x0025"], True),
+ (pl.Series([Decimal(12), None, Decimal(3)]), [Decimal(4)], True),
+ ],
+)
+def test_other_types(
+ series: pl.Series, extra_values: list[Any], sortable: bool
+) -> None:
+ # Needles are the series' own values (None included when present).
+ expected_values = series.to_list()
+ # Always search the series as-is and with nulls dropped.
+ series_variants = [series, series.drop_nulls()]
+ if sortable:
+ series_variants.extend(
+ [
+ series.sort(descending=False),
+ series.sort(descending=True),
+ ]
+ )
+ for s in series_variants:
+ for value in expected_values:
+ assert_index_of(s, value, convert_to_literal=True)
+ assert_index_of(s, value, convert_to_literal=False)
+ # Extra values may not be expressible as literal of correct dtype, so
+ # don't try:
+ for value in extra_values:
+ assert_index_of(s, value)
+
+
+# Before the output type would be list[idx-type] when no item was found
+def test_non_found_correct_type() -> None:
+    """A group without a match yields a null index next to the found ones."""
+    columns = [
+        pl.Series("a", [0, 1], pl.Int32),
+        pl.Series("b", [1, 2], pl.Int32),
+    ]
+    df = pl.DataFrame(columns)
+
+    result = df.group_by("a", maintain_order=True).agg(pl.col.b.index_of(1))
+    expected = pl.DataFrame({"a": [0, 1], "b": [0, None]})
+    assert_frame_equal(result, expected, check_dtypes=False)
+
+
+def test_error_on_multiple_values() -> None:
+    """Passing a needle with more than one value is rejected."""
+    haystack = pl.Series("a", [1, 2, 3])
+    needles = pl.Series([2, 3])
+    expected_message = "needle of `index_of` can only contain"
+    with pytest.raises(InvalidOperationError, match=expected_message):
+        haystack.index_of(needles)
+
+
+@pytest.mark.parametrize("convert_to_literal", [True, False])
+def test_enum(convert_to_literal: bool) -> None:
+    """Search an Enum series as-is, with nulls dropped, and in both sort orders."""
+    series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"]))
+    needles = series.to_list()
+    variants = [
+        series,
+        series.drop_nulls(),
+        series.sort(descending=False),
+        series.sort(descending=True),
+    ]
+    for s in variants:
+        for needle in needles:
+            assert_index_of(s, needle, convert_to_literal=convert_to_literal)
+
+
+@pytest.mark.parametrize(
+    "convert_to_literal",
+    [
+        pytest.param(
+            True,
+            marks=pytest.mark.xfail(
+                reason="https://github.com/pola-rs/polars/issues/20318"
+            ),
+        ),
+        pytest.param(
+            False,
+            marks=pytest.mark.xfail(
+                reason="https://github.com/pola-rs/polars/issues/20171"
+            ),
+        ),
+    ],
+)
+def test_categorical(convert_to_literal: bool) -> None:
+    """Categorical counterpart of the Enum test; both variants are marked xfail."""
+    series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
+    needles = series.to_list()
+    variants = [
+        series,
+        series.drop_nulls(),
+        series.sort(descending=False),
+        series.sort(descending=True),
+    ]
+    for s in variants:
+        for needle in needles:
+            assert_index_of(s, needle, convert_to_literal=convert_to_literal)