Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge reduce #402

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/argmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cubecl_core::prelude::*;
use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};

/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
#[derive(Debug)]
pub struct ArgMax;

#[cube]
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/argmin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cubecl_core::prelude::*;
use super::{lowest_coordinate_matching, ArgAccumulator, Reduce, ReduceInstruction};

/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
#[derive(Debug)]
pub struct ArgMin;

impl Reduce for ArgMin {
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-reduce/src/instructions/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

pub trait Reduce: Send + Sync + 'static {
pub trait Reduce: Send + Sync + 'static + std::fmt::Debug {
type Instruction<In: Numeric>: ReduceInstruction<In>;
}

Expand All @@ -14,7 +14,7 @@ pub trait Reduce: Send + Sync + 'static {
/// with their coordinate into an `AccumulatorItem`. Then, multiple `AccumulatorItem` are possibly fused
/// together into a single accumulator that is converted to the expected output type.
#[cube]
pub trait ReduceInstruction<In: Numeric>: Send + Sync + 'static {
pub trait ReduceInstruction<In: Numeric>: Send + Sync + 'static + std::fmt::Debug {
/// The intermediate state into which we accumulate new input elements.
/// This is most likely a `Line<T>` or a struct or tuple of lines.
type AccumulatorItem: CubeType;
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction, Sum};

#[derive(Debug)]
pub struct Mean;

impl Reduce for Mean {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/prod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction};

#[derive(Debug)]
pub struct Prod;

impl Reduce for Prod {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-reduce/src/instructions/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use cubecl_core::prelude::*;

use super::{Reduce, ReduceInstruction};

#[derive(Debug)]
pub struct Sum;

impl Reduce for Sum {
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-reduce/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod strategy;

pub use config::*;
pub use error::*;
use instructions::Reduce;
pub use instructions::Reduce;
pub use instructions::ReduceInstruction;
pub use strategy::*;

Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ default = [
]
exclusive-memory-only = ["cubecl-wgpu?/exclusive-memory-only"]
linalg = ["dep:cubecl-linalg"]
reduce = ["dep:cubecl-reduce"]
std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"]
template = ["cubecl-core/template"]

Expand All @@ -37,6 +38,7 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", default-features = f
cubecl-cuda = { path = "../cubecl-cuda", version = "0.4.0", default-features = false, optional = true }
cubecl-hip = { path = "../cubecl-hip", version = "0.4.0", default-features = false, optional = true }
cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", default-features = false, optional = true }
cubecl-reduce = { path = "../cubecl-reduce", version = "0.4.0", default-features = false, optional = true }
cubecl-runtime = { path = "../cubecl-runtime", version = "0.4.0", default-features = false }
cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.4.0", default-features = false, optional = true }
half = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ pub use cubecl_hip as hip;

#[cfg(feature = "linalg")]
pub use cubecl_linalg as linalg;

#[cfg(feature = "reduce")]
pub use cubecl_reduce as reduce;
Loading