From 57c562b618987c3c6f14e6a04f1a73ac6923e81b Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 30 Oct 2024 00:42:22 +0800 Subject: [PATCH] avoid unnecessary unsqueeze --- src/dataset.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/dataset.rs b/src/dataset.rs index 5a44a22..ba1d764 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -106,26 +106,24 @@ impl Batcher> for FSRSBatcher { item.history().map(|r| (r.delta_t, r.rating)).unzip(); delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); - let delta_t = Tensor::::from_floats( + let delta_t = Tensor::::from_floats( TensorData::new( delta_t, Shape { - dims: vec![pad_size], + dims: vec![1, pad_size], }, ), &self.device, - ) - .unsqueeze(); - let rating = Tensor::::from_data( + ); + let rating = Tensor::::from_data( TensorData::new( rating, Shape { - dims: vec![pad_size], + dims: vec![1, pad_size], }, ), &self.device, - ) - .unsqueeze(); + ); (delta_t, rating) }) .unzip();