From 6dceaaaa55c24e30f030840f8d9d70df27142213 Mon Sep 17 00:00:00 2001 From: pluveto Date: Fri, 29 Nov 2024 12:06:09 +0800 Subject: [PATCH] Optimize WCond and ReduceIndex implementations for improved performance and safety --- candle-core/src/cpu_backend/mod.rs | 121 ++++++++++++++++------------- 1 file changed, 67 insertions(+), 54 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 11ff1a406..70643d029 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1,4 +1,6 @@ //! Implementation of Backend Fns for CPU +use std::mem::MaybeUninit; + use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; @@ -70,11 +72,13 @@ impl Map2 for WCond<'_, I> { const OP: &'static str = "where"; #[inline(always)] fn f(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result> { - let vs = match ( + let contig_offsets = ( self.1.contiguous_offsets(), t_l.contiguous_offsets(), f_l.contiguous_offsets(), - ) { + ); + + let vs: Vec = match contig_offsets { (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => { let pred = &self.0[o1..o2]; let t = &t[o_t1..o_t2]; @@ -82,21 +86,25 @@ impl Map2 for WCond<'_, I> { pred.iter() .zip(t.iter().zip(f.iter())) .map(|(p, (&t, &f))| if p.is_true() { t } else { f }) - .collect::>() - } - _ => self - .1 - .strided_index() - .zip(t_l.strided_index().zip(f_l.strided_index())) - .map(|(i_p, (i_t, i_f))| { - if self.0[i_p].is_true() { - t[i_t] - } else { - f[i_f] - } - }) - .collect::>(), + .collect() + } + _ => { + let pred_indices = self.1.strided_index(); + let t_indices = t_l.strided_index(); + let f_indices = f_l.strided_index(); + pred_indices + .zip(t_indices.zip(f_indices)) + .map(|(i_p, (i_t, i_f))| { + if self.0[i_p].is_true() { + t[i_t] + } else { + f[i_f] + } + }) + .collect() + } }; + Ok(vs) } } @@ -106,14 +114,13 @@ struct ReduceIndex { use_min: bool, return_index: bool, } - impl ReduceIndex { - // The value gets replaced if f(s[current_acc], s[i]) returns true. + // The value gets replaced if `f(s[current_acc], s[i])` returns true. #[inline(always)] fn fold_impl(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result> where - T: Clone + Copy, - U: Clone + Copy, + T: Copy, + U: Copy, F: Fn(T, T) -> bool, G: Fn(T, usize) -> U, { @@ -121,66 +128,72 @@ impl ReduceIndex { let reduce_dim_stride = src_l.stride()[self.reduce_dim_index]; let dst_len = src_l.shape().elem_count() / reduce_dim_size; let mut dst: Vec = Vec::with_capacity(dst_len); - let dst_to_set = dst.spare_capacity_mut(); - let dst_to_set = - unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(dst_to_set) }; + + let dst_ptr = dst.spare_capacity_mut().as_mut_ptr(); + match src_l.contiguous_offsets() { Some((o1, o2)) => { let src = &src[o1..o2]; if reduce_dim_stride == 1 { - for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { - let start_src_i = start_src_i * reduce_dim_size; - let src = &src[start_src_i..start_src_i + reduce_dim_size]; + for i in 0..dst_len { + let start_src_i = i * reduce_dim_size; + let src_slice = &src[start_src_i..start_src_i + reduce_dim_size]; + let mut acc = 0; - let mut val = src[0]; - for (src_i, &s) in src.iter().enumerate() { + let mut val = src_slice[0]; + for (j, &s) in src_slice.iter().enumerate() { if f(val, s) { - acc = src_i; - val = s + acc = j; + val = s; } } - *dst_v = g(val, acc) + unsafe { + dst_ptr.add(i).write(MaybeUninit::new(g(val, acc))); + } } } else { - for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { - let (p, q) = ( - start_src_i / reduce_dim_stride, - start_src_i % reduce_dim_stride, - ); - // start_src_i = p * reduce_dim_stride + q - let start_src_i = p * reduce_dim_stride * reduce_dim_size + q; - let src = &src[start_src_i..]; + for i in 0..dst_len { + let p = i / reduce_dim_stride; + let q = i % reduce_dim_stride; + let start_src_i = p * reduce_dim_size * reduce_dim_stride + q; + let mut acc = 0; - let mut val = src[0]; - for src_i in 0..reduce_dim_size { - let s = src[src_i * reduce_dim_stride]; + let mut val = src[start_src_i]; + for j in 0..reduce_dim_size { + let s = src[start_src_i + j * reduce_dim_stride]; if f(val, s) { - acc = src_i; - val = s + acc = j; + val = s; } } - *dst_v = g(val, acc) + unsafe { + dst_ptr.add(i).write(MaybeUninit::new(g(val, acc))); + } } } } None => { let l = src_l.narrow(self.reduce_dim_index, 0, 1)?; - for (unstr_index, src_index) in l.strided_index().enumerate() { - let src = &src[src_index..]; + for (i, src_index) in l.strided_index().enumerate() { let mut acc = 0; - let mut val = src[0]; - for src_i in 0..reduce_dim_size { - let s = src[src_i * reduce_dim_stride]; + let mut val = src[src_index]; + for j in 0..reduce_dim_size { + let s = src[src_index + j * reduce_dim_stride]; if f(val, s) { - acc = src_i; - val = s + acc = j; + val = s; } } - dst_to_set[unstr_index] = g(val, acc) + unsafe { + dst_ptr.add(i).write(MaybeUninit::new(g(val, acc))); + } } } } - unsafe { dst.set_len(dst_len) }; + + unsafe { + dst.set_len(dst_len); + } Ok(dst) } }