Skip to content

Commit

Permalink
Implement suggested changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Feb 28, 2024
1 parent d7638a4 commit f2189d3
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,8 @@ at::Tensor XLANativeFunctions::as_strided_copy(
if (dim == 0 && tensor.numel() > 0) {
// If there's no specified dimension, return the first element of the
// storage. This behavior is consistent with eager.
return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0);
return take(tensor,
at::tensor({0}, at::TensorOptions().device(tensor.device())));
}

if (storage_size == 0) {
Expand All @@ -723,17 +724,9 @@ at::Tensor XLANativeFunctions::as_strided_copy(
}

// At this point, the following is true:
// - storage_size > 0
// - tensor.numel() > 0
// - dim > 0

// Flatten the tensor, so that it's easier to gather its elements.
tensor = view_copy_symint(tensor, {tensor.numel()});

if (storage_offset.has_value() && *storage_offset > 0) {
// If there's a storage_offset, slice this tensor, first.
tensor = slice_copy(tensor, 0, *storage_offset, c10::nullopt, 1);
}
XLA_CHECK(storage_size > 0);
XLA_CHECK(tensor.numel() > 0);
XLA_CHECK(dim > 0);

// Index tensor for gathering the needed elements into contiguous data.
//
Expand All @@ -756,7 +749,9 @@ at::Tensor XLANativeFunctions::as_strided_copy(
//
std::vector<int64_t> view_shape(dim, 1);
auto index_tensor =
at::tensor({0}, at::TensorOptions().dtype(at::kLong)).view(view_shape);
at::tensor({storage_offset.value_or(self.storage_offset())},
at::TensorOptions().dtype(at::kLong))
.view(view_shape);

// Then, add to the index_tensor the offset value introduced for each possible
// index of that corresponding dimension.
Expand Down Expand Up @@ -784,9 +779,7 @@ at::Tensor XLANativeFunctions::as_strided_copy(
}

// Finally, index the tensor with the computed indices.
c10::List<c10::optional<at::Tensor>> indices(
{c10::optional<at::Tensor>(index_tensor)});
return index(tensor, indices);
return take(tensor, index_tensor.to(tensor.device()));
}

at::Tensor XLANativeFunctions::as_strided_scatter(
Expand Down

0 comments on commit f2189d3

Please sign in to comment.