Make as_strided_copy materialize a new tensor with index #6624
Conversation
Haven't looked at the code in depth, but this sounds plausible. Will review tomorrow. @bdhirsh we could use this to functionalize
This is a great algorithm! Thank you @ysiraichi!
I think it is correct modulo potential corner-cases that may pop up.
torch_xla/csrc/aten_xla_type.cpp
Outdated
if (storage_offset.has_value() && *storage_offset > 0) {
  // If there's a storage_offset, slice this tensor, first.
  tensor = slice_copy(tensor, 0, *storage_offset, c10::nullopt, 1);
}
You can do this, or simply add storage_offset to index_tensor at the end.
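A minimal sketch of the equivalence being suggested, with made-up values (the names base, index_tensor, and storage_offset mirror the discussion; the data is hypothetical):

import torch

base = torch.arange(12)
storage_offset = 2
index_tensor = torch.tensor([0, 3, 6])  # indices computed as if the offset were 0

# What the PR does: slice away the offset first, then index.
sliced_then_indexed = base[storage_offset:][index_tensor]

# The suggested alternative: fold the offset into the indices at the end.
offset_indexed = base[index_tensor + storage_offset]

assert torch.equal(sliced_then_indexed, offset_indexed)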
You are right. I thought about it for a second and, for some reason, decided it wouldn't be correct. But, on second thought, it does make sense.
torch_xla/csrc/aten_xla_type.cpp
Outdated
// Flatten the tensor, so that it's easier to gather its elements.
tensor = view_copy_symint(tensor, {tensor.numel()});
Rather than this flattening + index, you can simply use torch.take.
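For context, torch.take treats its input as if it were flattened to 1-D, so it can subsume the explicit view_copy + index pair; a small sketch with hypothetical values:

import torch

t = torch.arange(6).reshape(2, 3)
idx = torch.tensor([0, 2, 4])

# Flattening then indexing...
flattened = t.reshape(-1)[idx]
# ...is equivalent to a single take over the original tensor.
taken = torch.take(t, idx)

assert torch.equal(flattened, taken)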
Force-pushed from bc6409c to 7afeb56
Logic lgtm
Force-pushed from 7afeb56 to 9ea4600
@JackCaoG Could you take a look at this PR whenever you have some time?
I believe these export test failures are unrelated. @JackCaoG @zpcore @frgossen @vanbasten23 @cota @golechwierowicz
Not really. @lsy323 do you know what this unbounded export test is doing?
Force-pushed from 4b66320 to 8bb2aea
@JackCaoG @alanwaketan Could you take a look at this PR when you have some time?
Generally LGTM; I only have one question.
// [[[0]]]
//
std::vector<int64_t> view_shape(dim, 1);
auto index_tensor =
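The snippet above is cut off, but here is a hedged Python reconstruction of the index computation it appears to perform (the function name and exact shapes are my assumptions, not the PR's code): each dimension contributes arange(size[d]) * stride[d], reshaped into a singleton view_shape so the per-dimension contributions broadcast together into the full size.

import torch

def strided_indices(size, stride, storage_offset=0):
    # Hypothetical reconstruction: start from the storage offset, then add each
    # dimension's contribution, viewed so that all terms broadcast together.
    dim = len(size)
    index_tensor = torch.tensor(storage_offset)
    for d in range(dim):
        view_shape = [1] * dim      # e.g. the [[[0]]]-style singleton shape above
        view_shape[d] = size[d]
        contribution = (torch.arange(size[d]) * stride[d]).view(view_shape)
        index_tensor = index_tensor + contribution
    return index_tensor

# Indices for as_strided(size=(2, 2), stride=(1, 2), storage_offset=0):
# tensor([[0, 2],
#         [1, 3]])
print(strided_indices((2, 2), (1, 2)))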
I assume this is computed by CPU eager mode in the following code?
Yes. Given the size, stride, and storage_offset arguments, we compute the correct indices for materializing the tensor ahead of time. There's no need to compute them at runtime.
Thanks!
Hi @ysiraichi, I found that this PR causes a performance regression on TPU v4-8 (it can also be reproduced on v3-8). The regression can be reproduced by running the following command:
I'm reverting this PR for now, since we are close to the 2.3 branch cut date (March 11th). Could you please re-land the PR after the perf regression is resolved? Thanks a lot
Fix: #5835
This PR implements arbitrary as_strided support by decomposing it into slicing + indexing. In summary, we slice the base tensor to comply with the given storage_offset, and then index a flattened version of the tensor, gathering the desired elements based on the given size and strides (more explanation in the code).

cc @miladm @JackCaoG @lezcano
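As a rough end-to-end illustration of the decomposition described above (a sketch only; the helper name and the use of Python rather than the PR's C++ lowering are mine), combining the precomputed indices with torch.take and checking against torch.as_strided:

import torch

def as_strided_via_take(tensor, size, stride, storage_offset=0):
    # Precompute gather indices from size/stride/offset alone (no runtime data needed).
    index_tensor = torch.tensor(storage_offset)
    for d, (sz, st) in enumerate(zip(size, stride)):
        view_shape = [1] * len(size)
        view_shape[d] = sz
        index_tensor = index_tensor + (torch.arange(sz) * st).view(view_shape)
    # Gather from the (implicitly flattened) tensor in one shot.
    return torch.take(tensor, index_tensor)

x = torch.arange(10)
expected = torch.as_strided(x, (3, 2), (2, 1), 1)
assert torch.equal(as_strided_via_take(x, (3, 2), (2, 1), 1), expected)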