From 49d844b3d3281100a61a33a4d7865046fcd44b2c Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 22 Jul 2024 09:03:23 -0400 Subject: [PATCH] Fix no-std --- crates/cubecl-common/Cargo.toml | 4 ++-- crates/cubecl-linalg/src/tensor/contiguous.rs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/cubecl-common/Cargo.toml b/crates/cubecl-common/Cargo.toml index 902431cd0..30302b243 100644 --- a/crates/cubecl-common/Cargo.toml +++ b/crates/cubecl-common/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true [features] default = ["std"] -std = ["rand/std"] +std = ["rand/std", "pollster"] [target.'cfg(target_family = "wasm")'.dependencies] getrandom = { workspace = true, features = ["js"] } @@ -24,7 +24,7 @@ spin = { workspace = true } # using in place of use std::sy derive-new = { workspace = true } serde = { workspace = true } rand = { workspace = true } -pollster = { workspace = true } +pollster = { workspace = true, optional = true } [dev-dependencies] dashmap = { workspace = true } diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 87e344395..c040b89bf 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -62,7 +62,7 @@ pub fn into_contiguous( // Vectorization is only enabled when the last dimension is contiguous. let rank = input.strides.len(); let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &input.shape, &input.strides, rank - 1); + tensor_vectorization_factor(&[4, 2], input.shape, input.strides, rank - 1); let num_elems: usize = input.shape.iter().product(); let cube_count = calculate_cube_count_elemwise( @@ -73,14 +73,14 @@ pub fn into_contiguous( let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle); into_contiguous_kernel::launch::( - &client, + client, cube_count, CubeDim::default(), TensorArg::vectorized( vectorization_factor, - &input.handle, - &input.strides, - &input.shape, + input.handle, + input.strides, + input.shape, ), TensorArg::vectorized( vectorization_factor,