From f2189d398f7a61c9a4b60c3545984e53a5d03c60 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 28 Feb 2024 10:30:53 -0300 Subject: [PATCH] Implement suggested changes. --- torch_xla/csrc/aten_xla_type.cpp | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 9c7a3445c6ec..b91239bc7a85 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -711,7 +711,8 @@ at::Tensor XLANativeFunctions::as_strided_copy( if (dim == 0 && tensor.numel() > 0) { // If there's no specified dimension, return the first element of the // storage. This behavior is consistent with eager. - return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0); + return take(tensor, + at::tensor({0}, at::TensorOptions().device(tensor.device()))); } if (storage_size == 0) { @@ -723,17 +724,9 @@ at::Tensor XLANativeFunctions::as_strided_copy( } // At this point, the following is true: - // - storage_size > 0 - // - tensor.numel() > 0 - // - dim > 0 - - // Flatten the tensor, so that it's easier to gather its elements. - tensor = view_copy_symint(tensor, {tensor.numel()}); - - if (storage_offset.has_value() && *storage_offset > 0) { - // If there's a storage_offset, slice this tensor, first. - tensor = slice_copy(tensor, 0, *storage_offset, c10::nullopt, 1); - } + XLA_CHECK(storage_size > 0); + XLA_CHECK(tensor.numel() > 0); + XLA_CHECK(dim > 0); // Index tensor for gathering the needed elements into contiguous data. // @@ -756,7 +749,9 @@ at::Tensor XLANativeFunctions::as_strided_copy( // std::vector view_shape(dim, 1); auto index_tensor = - at::tensor({0}, at::TensorOptions().dtype(at::kLong)).view(view_shape); + at::tensor({storage_offset.value_or(self.storage_offset())}, + at::TensorOptions().dtype(at::kLong)) + .view(view_shape); // Then, add to the index_tensor the offset value introduced for each possible // index of that corresponding dimension. @@ -784,9 +779,7 @@ at::Tensor XLANativeFunctions::as_strided_copy( } // Finally, index the tensor with the computed indices. - c10::List> indices( - {c10::optional(index_tensor)}); - return index(tensor, indices); + return take(tensor, index_tensor.to(tensor.device())); } at::Tensor XLANativeFunctions::as_strided_scatter(