Re-land: Make as_strided_copy materialize a new tensor with index. #6697

Merged: 3 commits into master, Mar 19, 2024

Conversation

ysiraichi (Collaborator)

Re-land: #6624

This PR adds a fast path on top of #6624 changes.

Fast path: keeps the old behavior of as_strided_copy

  • Taken when the size and strides specify a non-overlapping and dense tensor

Slow path: the new behavior

  • Slower, due to CPU dispatch and index computation
  • Should work with any argument combination (see the sketch after this list)
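To make the dispatch concrete, here is a minimal Python sketch of the two paths. The helper names `is_non_overlapping_and_dense`, `as_strided_copy_slow`, and the wrapper are hypothetical, written for illustration only; they are not the actual C++ implementation in this PR, and the slow path assumes a contiguous input so that `flatten()` matches the underlying storage.

```python
import torch

def is_non_overlapping_and_dense(size, stride):
    # Illustrative fast-path check: visiting dimensions in order of
    # increasing stride, each dimension must begin exactly where the
    # previous one ends, so the output tiles its storage with no gaps
    # and no aliasing.
    dims = sorted(range(len(size)), key=lambda d: stride[d])
    expected = 1
    for d in dims:
        if size[d] == 0:
            return True  # empty tensors trivially qualify
        if size[d] == 1:
            continue  # singleton dimensions cannot overlap
        if stride[d] != expected:
            return False
        expected *= size[d]
    return True

def as_strided_copy_slow(t, size, stride, storage_offset=0):
    # Illustrative slow path: compute the flat storage index of every
    # output element, then materialize a fresh tensor with a gather.
    # This handles overlapping and non-dense (size, stride) pairs.
    idx = torch.full(size, storage_offset, dtype=torch.long)
    for d, (n, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[d] = n
        idx = idx + (torch.arange(n) * st).view(shape)
    return t.flatten()[idx]  # assumes t is contiguous

def as_strided_copy(t, size, stride, storage_offset=0):
    if is_non_overlapping_and_dense(size, stride):
        # Fast path: the old behavior, a plain strided view plus copy.
        return t.as_strided(size, stride, storage_offset).clone()
    return as_strided_copy_slow(t, size, stride, storage_offset)
```

For example, `as_strided_copy(torch.arange(6.), (2, 3), (3, 1))` takes the fast path, while the overlapping layout `(2, 2), (1, 1)` falls through to the slow path and yields `[[0, 1], [1, 2]]`.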

cc @miladm @JackCaoG @lsy323

ysiraichi requested review from lsy323 and JackCaoG on March 8, 2024
ysiraichi (Collaborator, Author)

I will test for the regression described here on the GPU machine I have access to.

JackCaoG (Collaborator) commented Mar 8, 2024

The dynamo issue can be fixed by rebasing; fine to ignore.

ysiraichi force-pushed the ysiraichi/fix-asstrided branch from 84ec60b to 1a204f7 on March 8, 2024
ysiraichi (Collaborator, Author)

@lsy323 Could you help me check whether the regression is gone?

JackCaoG (Collaborator) commented Mar 8, 2024

Do we need this PR in the 2.3 release? It is a rather dangerous change; if we don't have a strong reason, I'd rather leave it in nightly for now.

miladm (Collaborator) commented Mar 11, 2024

@vanbasten23 Can you please help @ysiraichi benchmark this fix on TPU and confirm the perf outcome?

@JackCaoG Given the risk, I'd be OK with leaving this PR out of 2.3.

JackCaoG (Collaborator)

Yeah, unless there is a strong reason, I would prefer to leave this out of the 2.3 release.

JackCaoG (Collaborator)

Do we have bandwidth to test this one? Otherwise we can merge and see if the DDP test starts to fail tomorrow...

vanbasten23 (Collaborator)

> Do we have bandwidth to test this one? Otherwise we can merge and see if the DDP test starts to fail tomorrow...

I'm running the tests in #6624 (comment).

vanbasten23 (Collaborator)

@ysiraichi Sorry for the delayed response. I tested on my v3-8 TPU. Before this PR (master at 6ac3223):

root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
Epoch 1 train begin 03:32:23
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42

With the PR:

root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
Epoch 1 train begin 04:06:25
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29

Comparing step 0 to step 400: about 4m37s before (03:33:05 to 03:37:42) versus about 4m22s with this PR (04:07:07 to 04:11:29), so I don't see any slowdown. The change LGTM. Thanks Yukio.

ysiraichi (Collaborator, Author)

Thanks, @vanbasten23.

ysiraichi merged commit 27a7dd3 into master on Mar 19, 2024. 18 checks passed.