Skip to content

Commit

Permalink
Mini refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jul 18, 2024
1 parent 5729849 commit b9eecba
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 42 deletions.
12 changes: 4 additions & 8 deletions crates/cubecl-linalg/src/matmul/matmul_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::marker::PhantomData;

use crate::{matmul::test_utils::range_tensor_with_factor, tensor::TensorHandle};
use cubecl_core::{frontend::F32, CubeElement, Runtime};

Expand Down Expand Up @@ -82,12 +80,10 @@ impl MatmulTestCase {
range_tensor_with_factor::<R>(self.batch, self.m, self.k, self.factor, device);
let tensor_2 =
range_tensor_with_factor::<R>(self.batch, self.k, self.n, self.factor, device);
let out = TensorHandle {
handle: create_empty::<R>(self.batch * self.m, self.n, device),
shape: vec![self.batch, self.m, self.n],
strides: vec![self.m * self.n, self.n, 1],
elem: PhantomData,
};
let out = TensorHandle::new_contiguous(
vec![self.batch, self.m, self.n],
create_empty::<R>(self.batch * self.m, self.n, device),
);

let expected = self.matmul_cpu(
f32::from_bytes(&R::client(device).read(tensor_1.handle.clone().binding())),
Expand Down
38 changes: 6 additions & 32 deletions crates/cubecl-linalg/src/matmul/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::{marker::PhantomData, ops::Range};

use bytemuck::cast_slice;
use cubecl_core::{
frontend::{F16, F32},
ir::{Elem, FloatKind},
server::Handle,
CubeElement, Feature, Runtime,
};
use std::ops::Range;

use crate::tensor::TensorHandle;

Expand All @@ -27,12 +26,7 @@ pub(crate) fn range_tensor_f16<R: Runtime>(

let handle = client.create(cast_slice(&data));

TensorHandle {
handle,
shape: vec![x, y],
strides: vec![y, 1],
elem: PhantomData,
}
TensorHandle::new_contiguous(vec![x, y], handle)
}

pub(crate) fn range_tensor<R: Runtime>(
Expand All @@ -50,12 +44,7 @@ pub(crate) fn range_tensor<R: Runtime>(

let handle = client.create(cast_slice(&data));

TensorHandle {
handle,
shape: vec![x, y],
strides: vec![y, 1],
elem: PhantomData,
}
TensorHandle::new_contiguous(vec![x, y], handle)
}

pub(crate) fn range_tensor_with_factor<R: Runtime>(
Expand All @@ -75,12 +64,7 @@ pub(crate) fn range_tensor_with_factor<R: Runtime>(

let handle = client.create(cast_slice(&data));

TensorHandle {
handle,
shape: vec![batch, x, y],
strides: vec![x * y, y, 1],
elem: PhantomData,
}
TensorHandle::new_contiguous(vec![batch, x, y], handle)
}

pub(crate) fn range_tensor_transposed<R: Runtime>(
Expand All @@ -101,12 +85,7 @@ pub(crate) fn range_tensor_transposed<R: Runtime>(

let handle = client.create(cast_slice(&data));

TensorHandle {
handle,
shape: vec![x, y],
strides: vec![y, 1],
elem: PhantomData,
}
TensorHandle::new_contiguous(vec![x, y], handle)
}

pub(crate) fn zeros_tensor<R: Runtime>(
Expand All @@ -120,12 +99,7 @@ pub(crate) fn zeros_tensor<R: Runtime>(
let data: Vec<f32> = vec![0.; n_elements];
let handle = client.create(cast_slice(&data));

TensorHandle {
handle,
shape: vec![x, y],
strides: vec![y, 1],
elem: PhantomData,
}
TensorHandle::new_contiguous(vec![x, y], handle)
}

pub(crate) fn create_empty<R: Runtime>(
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-linalg/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ where
pub shape: Vec<usize>,
/// The strides of the tensor.
pub strides: Vec<usize>,
pub(crate) elem: PhantomData<E>,
elem: PhantomData<E>,
}

impl<R, E> core::fmt::Debug for TensorHandle<R, E>
Expand All @@ -27,10 +27,11 @@ where
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"Tensor {{ shape: {:?}, strides: {:?}, runtime: {}}}",
"Tensor {{ shape: {:?}, strides: {:?}, runtime: {}, dtype: {}}}",
self.shape,
self.strides,
R::name(),
core::any::type_name::<E>(),
))
}
}
Expand Down

0 comments on commit b9eecba

Please sign in to comment.