Reapply "Lower RandPerm" (#6394) #6427

Closed
wants to merge 1 commit into from


changm
Collaborator

@changm changm commented Jan 31, 2024

This reverts commit 2f4275f.

Also fixes test infrastructure to call xla.mark_step() as required.

@changm changm self-assigned this Jan 31, 2024
@changm
Collaborator Author

changm commented Jan 31, 2024

Chatted with Wonjoo offline, and he suggested copying the test into the internal XLA directory, which is done here. Fundamentally, we think the test fails because of the lazy nature of PyTorch/XLA: the test harness never calls mark_step(), so the call to randperm is recompiled on every loop iteration. Each recompilation produces a different permutation, so a different set of test indices gets swapped on each iteration and the test fails. Consider the following code:

  1 import torch
  2 import torch_xla
  3 import torch_xla.core.xla_model as xm
  4 import logging
  5
  6 def ref_index_copy(tgt, dim, idx, src):
  7   for i in range(idx.size(0)):
        # idx is recompiled 10x, so we get a new random number every time.
  8     idx_dest = dim * (slice(None),) + (idx[i],)
  9     idx_src = dim * (slice(None),) + (i,)
 10     tgt[idx_dest] = src[idx_src]
 11
 12 target = torch.tensor([8, 6, 1, 0, 2, 3, 5, 4, 7, 9])
 13 swap_index = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 14 randperm_results = torch.randperm(10, device=xm.xla_device())
 15
 16 # Fails because we call RandPerm::Lower 10 times and randperm_results is different every loop iteration.
 17 ref_index_copy(target, 0, randperm_results, swap_index)
 18 print(target)

It's not clear how to keep line 8 from recompiling on every loop iteration. Inserting a mark_step() call at line 15, right after the randperm call, fixes the problem.
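To make the failure mode concrete, here is a minimal pure-Python sketch of the behavior described above. It is a toy model and not how PyTorch/XLA is implemented: `LazyPerm`, `materialize()`, and the `evaluations` counter are hypothetical names invented for illustration. `materialize()` plays the role that mark_step() plays in the real snippet.

```python
import random


class LazyPerm:
    """Toy stand-in for a lazy randperm result: it stores a recipe, not values."""

    def __init__(self, n, rng):
        self.n = n
        self.rng = rng
        self.cached = None    # filled in by materialize()
        self.evaluations = 0  # how many times the "graph" was executed

    def _run(self):
        # Re-executing the recipe draws a fresh permutation each time.
        self.evaluations += 1
        perm = list(range(self.n))
        self.rng.shuffle(perm)
        return perm

    def materialize(self):
        # Hypothetical analogue of mark_step(): evaluate once and pin the result.
        self.cached = self._run()

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Without a pinned result, every read re-runs the recipe.
        values = self.cached if self.cached is not None else self._run()
        return values[i]


def ref_index_copy(tgt, idx, src):
    # Same loop shape as the reference implementation in the snippet (dim = 0).
    for i in range(len(idx)):
        tgt[idx[i]] = src[i]
    return tgt


target = [8, 6, 1, 0, 2, 3, 5, 4, 7, 9]
src = list(range(10))
rng = random.Random(0)

# Case 1: never materialized, so each idx[i] reads from a different permutation.
lazy = LazyPerm(10, rng)
broken = ref_index_copy(list(target), lazy, src)
print("evaluations without materialize:", lazy.evaluations)  # 10

# Case 2: materialized once before the loop, so every read sees one permutation.
pinned = LazyPerm(10, rng)
pinned.materialize()
fixed = ref_index_copy(list(target), pinned, src)
print("evaluations with materialize:", pinned.evaluations)  # 1
print("result is a permutation:", sorted(fixed) == src)  # True
```

In the unmaterialized case, some positions can be written twice and others not at all, which matches the reported test failure. Pinning the permutation before the loop guarantees each position is written exactly once.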

@wonjoolee95
Collaborator

Thanks! Can we combine the two tests into one python file? Maybe name it something like test_python_ops.py so we can put all the python test code there.

@changm
Collaborator Author

changm commented Feb 1, 2024

I moved everything into a single file. Also note that #6369 disabled a bunch of CPU tests and now test_put_cpu_bfloat16 is failing at MASTER, so I had to disable that test here.

@changm changm force-pushed the randperm branch 3 times, most recently from 32b119f to 9136c51 Compare February 2, 2024 17:34
@changm
Collaborator Author

changm commented Feb 2, 2024

Alright fixed the tests and this PR should be good to review again. Thanks!

@changm
Collaborator Author

changm commented Feb 6, 2024

Closing to reopen on master repo so that CI/CD properly runs.

@changm changm closed this Feb 6, 2024
2 participants