Skip to content

Commit

Permalink
Merge pull request #402 from tracel-ai/merge-reduce
Browse files Browse the repository at this point in the history
Merge reduce
  • Loading branch information
nathanielsimard authored Jan 8, 2025
2 parents 6c60851 + aae10ef commit 2552226
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/argmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cubecl_core::prelude::*;
use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};

/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
#[derive(Debug)]
pub struct ArgMax;

#[cube]
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/argmin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cubecl_core::prelude::*;
use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};

/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
#[derive(Debug)]
pub struct ArgMin;

impl Reduce for ArgMin {
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-reduce/src/instructions/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

pub trait Reduce: Send + Sync + 'static {
pub trait Reduce: Send + Sync + 'static + std::fmt::Debug {
type Instruction<In: Numeric>: ReduceInstruction<In>;
}

Expand All @@ -14,7 +14,7 @@ pub trait Reduce: Send + Sync + 'static {
/// with their coordinate into an `AccumulatorItem`. Then, multiple `AccumulatorItem` are possibly fused
/// together into a single accumulator that is converted to the expected output type.
#[cube]
pub trait ReduceInstruction<In: Numeric>: Send + Sync + 'static {
pub trait ReduceInstruction<In: Numeric>: Send + Sync + 'static + std::fmt::Debug {
/// The intermediate state into which we accumulate new input elements.
/// This is most likely a `Line<T>` or a struct or tuple of lines.
type AccumulatorItem: CubeType;
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction, Sum};

#[derive(Debug)]
pub struct Mean;

impl Reduce for Mean {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction};

#[derive(Debug)]
pub struct Prod;

impl Reduce for Prod {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction};

#[derive(Debug)]
pub struct Sum;

impl Reduce for Sum {
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-reduce/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod strategy;

pub use config::*;
pub use error::*;
use instructions::Reduce;
pub use instructions::Reduce;
pub use instructions::ReduceInstruction;
pub use strategy::*;

Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ default = [
]
exclusive-memory-only = ["cubecl-wgpu?/exclusive-memory-only"]
linalg = ["dep:cubecl-linalg"]
reduce = ["dep:cubecl-reduce"]
std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"]
template = ["cubecl-core/template"]

Expand All @@ -37,6 +38,7 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", default-features = f
cubecl-cuda = { path = "../cubecl-cuda", version = "0.4.0", default-features = false, optional = true }
cubecl-hip = { path = "../cubecl-hip", version = "0.4.0", default-features = false, optional = true }
cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", default-features = false, optional = true }
cubecl-reduce = { path = "../cubecl-reduce", version = "0.4.0", default-features = false, optional = true }
cubecl-runtime = { path = "../cubecl-runtime", version = "0.4.0", default-features = false }
cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.4.0", default-features = false, optional = true }
half = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ pub use cubecl_hip as hip;

#[cfg(feature = "linalg")]
pub use cubecl_linalg as linalg;

#[cfg(feature = "reduce")]
pub use cubecl_reduce as reduce;

0 comments on commit 2552226

Please sign in to comment.