diff --git a/src/dataset.rs b/src/dataset.rs index 5a44a22..ba1d764 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -106,26 +106,24 @@ impl Batcher> for FSRSBatcher { item.history().map(|r| (r.delta_t, r.rating)).unzip(); delta_t.resize(pad_size, 0); rating.resize(pad_size, 0); - let delta_t = Tensor::::from_floats( + let delta_t = Tensor::::from_floats( TensorData::new( delta_t, Shape { - dims: vec![pad_size], + dims: vec![1, pad_size], }, ), &self.device, - ) - .unsqueeze(); - let rating = Tensor::::from_data( + ); + let rating = Tensor::::from_data( TensorData::new( rating, Shape { - dims: vec![pad_size], + dims: vec![1, pad_size], }, ), &self.device, - ) - .unsqueeze(); + ); (delta_t, rating) }) .unzip();