Implementing an efficient Argsort #1145

Closed
PABannier opened this issue Jan 9, 2022 · 5 comments

Comments

@PABannier

PABannier commented Jan 9, 2022

I'm currently working on a project where I need an argsort function (in descending order).

// kkt is an instance of Array1<T>; ws_size is the number of indices to keep (defined elsewhere).

let mut kkt_with_indices: Vec<(usize, T)> = kkt.iter().copied().enumerate().collect();

kkt_with_indices.sort_unstable_by(|(_, p), (_, q)| {
    // Swapped order for sorting in descending order.
    q.partial_cmp(p).expect("kkt must not be NaN.")
});

let ws: Vec<usize> = kkt_with_indices
    .iter()
    .map(|&(ind, _)| ind)
    .take(ws_size)
    .collect();

My implementation works, but I think it could be optimized further by working only with Array and not going through a Vec, which requires an extra memory allocation. In my case this piece of code is called very often (from hundreds of thousands up to a million times), so an efficient argsort function would be amazing.
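Since the snippet above keeps only the first ws_size indices, one option not discussed in this thread is to avoid the full sort: partially select the largest entries with the standard library's `select_nth_unstable_by`, sort only those, and reuse a caller-owned index buffer across calls. This is a minimal sketch, assuming `f64` values in a plain slice (an `Array1` in standard layout can provide one through `as_slice()`); the function name `top_k_desc_into` is hypothetical, and `k` plays the role of ws_size.

```rust
/// Hypothetical helper (not part of ndarray): fills `indices` with the
/// positions of the `k` largest entries of `values`, largest first.
/// `indices` is a caller-owned buffer reused across calls, so once its
/// capacity reaches `values.len()` no further allocation happens.
pub fn top_k_desc_into(values: &[f64], k: usize, indices: &mut Vec<usize>) {
    indices.clear();
    indices.extend(0..values.len());
    let k = k.min(values.len());
    // Descending order: compare the value at `j` with the value at `i`.
    // Like the original snippet, this panics if a NaN is encountered.
    let desc = |&i: &usize, &j: &usize| {
        values[j].partial_cmp(&values[i]).expect("values must not be NaN.")
    };
    if k > 0 && k < indices.len() {
        // Linear on average: afterwards the first `k` slots hold the `k`
        // largest values, in no particular order.
        indices.select_nth_unstable_by(k - 1, desc);
    }
    // Drop everything past the working set, then order only the survivors.
    indices.truncate(k);
    indices.sort_unstable_by(desc);
}
```

Whether this beats a full sort depends on how small ws_size is relative to the array length, so it is worth benchmarking on real data. As with any unstable sort, the relative order of equal values is unspecified.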

I've seen an open topic on sorting (#195), but did not find an implementation of argsort. I'd like to implement it as my first contribution to the library, but I need some guidance. Would anybody be willing to help?

PABannier changed the title from "Efficient argsort" to "Implementing an efficient Argsort" on Jan 9, 2022
@jturner314
Member

jturner314 commented Jan 9, 2022

Here's a simple argsort for &[T]:

pub fn argsort<T>(slice: &[T]) -> Vec<usize>
where
    T: Ord,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by_key(|&index| &slice[index]);
    indices
}

A simple implementation for ArrayBase could be written similarly:

use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by_key(|&index| &arr[index]);
    indices
}

It may be faster to have special cases if the array is contiguous:

use ndarray::prelude::*;
use ndarray::Data;

pub fn argsort<S>(arr: &ArrayBase<S, Ix1>) -> Vec<usize>
where
    S: Data,
    S::Elem: Ord,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    if let Some(slice) = arr.as_slice() {
        indices.sort_unstable_by_key(|&index| &slice[index]);
    } else {
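        // A view with stride -1 becomes contiguous once its axis is inverted.
        let mut inverted = arr.view();
        inverted.invert_axis(Axis(0));
        if let Some(inv_slice) = inverted.as_slice() {
            // inv_slice[k] == arr[n - 1 - k], so map each index of `arr` before comparing.
            let n = inv_slice.len();
            indices.sort_unstable_by_key(|&index| &inv_slice[n - 1 - index]);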
        } else {
            indices.sort_unstable_by_key(|&index| &arr[index]);
        }
    }
    indices
}

I suspect that the compiler won't be able to eliminate the bounds checks by itself, so it may be faster to switch to unchecked indexing (.uget() for ArrayBase and .get_unchecked() for &[T]).
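For the slice case, the unchecked variant could look like the following minimal sketch (the name `argsort_unchecked` is hypothetical). The `unsafe` block is sound only because every index comes from `0..slice.len()`, and whether it is actually faster than the checked version should be confirmed with a benchmark.

```rust
/// Hypothetical unchecked-indexing argsort for plain slices (ascending order).
pub fn argsort_unchecked<T: Ord>(slice: &[T]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    indices.sort_unstable_by(|&i, &j| {
        // SAFETY: every entry of `indices` comes from `0..slice.len()` and the
        // sort only permutes the entries, so `i` and `j` are always in bounds.
        unsafe { slice.get_unchecked(i).cmp(slice.get_unchecked(j)) }
    });
    indices
}
```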

An argsort method would be a good addition to the Sort1dExt trait in the ndarray-stats crate.

@PABannier
Author

What if the generic type T does not implement the Ord trait? In my code above, T is bounded by the Float trait specifically to support f32 and f64.

@jturner314
Member

By the way, I just remembered rust-ndarray/ndarray-stats#84, which may be of interest. (An argsort implementation specifically for 1-D arrays would be faster than the generic-dimensional implementation in that PR, though.)

> What if the generic type T does not implement the Ord trait? In my code above, T is bounded by the Float trait specifically to support f32 and f64.

It depends on how you want to handle NaNs. If you know there aren't any NaNs, then you could use a wrapper type which implements Ord, such as the one provided by the noisy_float crate. A more general approach is to add argsort_by and argsort_by_key functions which accept closures so that the user can specify how to compare elements, e.g.:

use ndarray::prelude::*;
use ndarray::Data;
use std::cmp::Ordering;

pub fn argsort_by<S, F>(arr: &ArrayBase<S, Ix1>, mut compare: F) -> Vec<usize>
where
    S: Data,
    F: FnMut(&S::Elem, &S::Elem) -> Ordering,
{
    let mut indices: Vec<usize> = (0..arr.len()).collect();
    indices.sort_unstable_by(move |&i, &j| compare(&arr[i], &arr[j]));
    indices
}

fn main() {
    let arr = array![3., 0., 2., 1.];
    assert_eq!(
        argsort_by(&arr, |a, b| a
            .partial_cmp(b)
            .expect("Elements must not be NaN.")),
        vec![1, 3, 2, 0],
    );
}
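The reply mentions an argsort_by_key companion but does not show it. Here is a minimal sketch for plain slices using only the standard library (the `ArrayBase` version would index with `&arr[i]` in the same way); the signature is an assumption, not an existing ndarray or ndarray-stats API.

```rust
/// Hypothetical argsort_by_key for slices: sorts indices by `key(&slice[i])`.
pub fn argsort_by_key<T, K, F>(slice: &[T], mut key: F) -> Vec<usize>
where
    K: Ord,
    F: FnMut(&T) -> K,
{
    let mut indices: Vec<usize> = (0..slice.len()).collect();
    // The key function is called on the referenced element, not on the index.
    indices.sort_unstable_by_key(|&i| key(&slice[i]));
    indices
}
```

`sort_unstable_by_key` may call the key function many times per element; if the key is expensive to compute, `sort_by_cached_key` computes it only once per element at the cost of an extra allocation.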

@PABannier
Author

Thanks a lot. I think it would be a very valuable addition to the crate.

@Kastakin

As of Rust 1.62.0, the total_cmp method (f32::total_cmp and f64::total_cmp) can be used to deal with NaNs as well.
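This is a minimal sketch of the descending argsort from the original question built on `f64::total_cmp`, assuming `f64` data in a plain slice (the name `argsort_desc_f64` is hypothetical). Unlike the `partial_cmp(...).expect(...)` version, it never panics. Under the total order that `total_cmp` implements, positive NaNs compare greater than positive infinity, so they come first in descending order, and negative NaNs come last.

```rust
/// Hypothetical descending argsort for f64 slices (requires Rust 1.62+).
pub fn argsort_desc_f64(values: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..values.len()).collect();
    // Comparing the value at `j` against the value at `i` yields descending order.
    indices.sort_unstable_by(|&i, &j| values[j].total_cmp(&values[i]));
    indices
}
```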
