Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize WCond and ReduceIndex implementations for improved performance and safety #2648

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 67 additions & 54 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! Implementation of Backend Fns for CPU
use std::mem::MaybeUninit;

use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
Expand Down Expand Up @@ -70,33 +72,39 @@ impl<I: IntDType> Map2 for WCond<'_, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
let vs = match (
let contig_offsets = (
self.1.contiguous_offsets(),
t_l.contiguous_offsets(),
f_l.contiguous_offsets(),
) {
);

let vs: Vec<T> = match contig_offsets {
(Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
let pred = &self.0[o1..o2];
let t = &t[o_t1..o_t2];
let f = &f[o_f1..o_f2];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(p, (&t, &f))| if p.is_true() { t } else { f })
.collect::<Vec<_>>()
}
_ => self
.1
.strided_index()
.zip(t_l.strided_index().zip(f_l.strided_index()))
.map(|(i_p, (i_t, i_f))| {
if self.0[i_p].is_true() {
t[i_t]
} else {
f[i_f]
}
})
.collect::<Vec<_>>(),
.collect()
}
_ => {
let pred_indices = self.1.strided_index();
let t_indices = t_l.strided_index();
let f_indices = f_l.strided_index();
pred_indices
.zip(t_indices.zip(f_indices))
.map(|(i_p, (i_t, i_f))| {
if self.0[i_p].is_true() {
t[i_t]
} else {
f[i_f]
}
})
.collect()
}
};

Ok(vs)
}
}
Expand All @@ -106,81 +114,86 @@ struct ReduceIndex {
use_min: bool,
return_index: bool,
}

impl ReduceIndex {
// The value gets replaced if f(s[current_acc], s[i]) returns true.
// The value gets replaced if `f(s[current_acc], s[i])` returns true.
#[inline(always)]
fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
where
T: Clone + Copy,
U: Clone + Copy,
T: Copy,
U: Copy,
F: Fn(T, T) -> bool,
G: Fn(T, usize) -> U,
{
let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
let dst_to_set =
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };

let dst_ptr = dst.spare_capacity_mut().as_mut_ptr();

match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
if reduce_dim_stride == 1 {
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
let start_src_i = start_src_i * reduce_dim_size;
let src = &src[start_src_i..start_src_i + reduce_dim_size];
for i in 0..dst_len {
let start_src_i = i * reduce_dim_size;
let src_slice = &src[start_src_i..start_src_i + reduce_dim_size];

let mut acc = 0;
let mut val = src[0];
for (src_i, &s) in src.iter().enumerate() {
let mut val = src_slice[0];
for (j, &s) in src_slice.iter().enumerate() {
if f(val, s) {
acc = src_i;
val = s
acc = j;
val = s;
}
}
*dst_v = g(val, acc)
unsafe {
dst_ptr.add(i).write(MaybeUninit::new(g(val, acc)));
}
}
} else {
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
let (p, q) = (
start_src_i / reduce_dim_stride,
start_src_i % reduce_dim_stride,
);
// start_src_i = p * reduce_dim_stride + q
let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
let src = &src[start_src_i..];
for i in 0..dst_len {
let p = i / reduce_dim_stride;
let q = i % reduce_dim_stride;
let start_src_i = p * reduce_dim_size * reduce_dim_stride + q;

let mut acc = 0;
let mut val = src[0];
for src_i in 0..reduce_dim_size {
let s = src[src_i * reduce_dim_stride];
let mut val = src[start_src_i];
for j in 0..reduce_dim_size {
let s = src[start_src_i + j * reduce_dim_stride];
if f(val, s) {
acc = src_i;
val = s
acc = j;
val = s;
}
}
*dst_v = g(val, acc)
unsafe {
dst_ptr.add(i).write(MaybeUninit::new(g(val, acc)));
}
}
}
}
None => {
let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
for (unstr_index, src_index) in l.strided_index().enumerate() {
let src = &src[src_index..];
for (i, src_index) in l.strided_index().enumerate() {
let mut acc = 0;
let mut val = src[0];
for src_i in 0..reduce_dim_size {
let s = src[src_i * reduce_dim_stride];
let mut val = src[src_index];
for j in 0..reduce_dim_size {
let s = src[src_index + j * reduce_dim_stride];
if f(val, s) {
acc = src_i;
val = s
acc = j;
val = s;
}
}
dst_to_set[unstr_index] = g(val, acc)
unsafe {
dst_ptr.add(i).write(MaybeUninit::new(g(val, acc)));
}
}
}
}
unsafe { dst.set_len(dst_len) };

unsafe {
dst.set_len(dst_len);
}
Ok(dst)
}
}
Expand Down