[test] Fori loop simple case without hard-code #7031

Closed

Conversation

@ManfeiBai ManfeiBai (Collaborator) commented May 6, 2024

Code based on #7012: create a post-order tracing interface at the Python level, to enable post-order tracing for an XLAComputation similar to https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1293.
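
To make the intent concrete, here is a minimal sketch of how such a Python-level entry point might be exercised. The binding name _get_xla_computation is hypothetical (it only stands in for however GetXLAComputation would eventually be exposed); the _get_xla_tensors_text debug binding used at the end already exists in torch_xla.

# Minimal sketch, assuming a hypothetical pybind wrapper around GetXLAComputation
# exposed as torch_xla._XLAC._get_xla_computation (illustrative name only).
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = a @ a + a  # builds lazy IR; nothing has been compiled yet

# Hypothetical call: post-order trace `b` and return the built XLAComputation.
# computation = torch_xla._XLAC._get_xla_computation([b])

# Today the traced IR can at least be inspected with the existing debug binding:
print(torch_xla._XLAC._get_xla_tensors_text([b]))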

// runtime::ComputationClient::ComputationPtr computation;
// };
// std::vector<runtime::ComputationClient::ComputationPtr>
runtime::ComputationClient::ComputationPtr GetXLAComputation(
@ManfeiBai ManfeiBai (Collaborator, Author) commented:

Hi @miladm, this is the draft code for the post-order tracing compile function:

// Compile the collected post-order data and return the resulting computation,
// matching the runtime::ComputationClient::ComputationPtr return type of
// GetXLAComputation above.
runtime::ComputationClient::ComputationPtr computation =
    Compile(*tensors, devices, coll, &po_data, ir_values).computation;
return computation;

@ManfeiBai ManfeiBai (Collaborator, Author) commented:

Hi @miladm, this is the draft implementation of the post-order tracing compile function.

@ManfeiBai ManfeiBai changed the title [test] Fori loop simple case testnewone [test] Fori loop simple case without hard-cde May 6, 2024
@ManfeiBai ManfeiBai changed the title [test] Fori loop simple case without hard-cde [test] Fori loop simple case without hard-code May 6, 2024
@JackCaoG JackCaoG (Collaborator) commented May 7, 2024

I have

  # Assumes: import torch; import torch_xla.core.xla_model as xm;
  # from torch._higher_order_ops.while_loop import while_loop
  def test_while_loop_tpu_simple_linear(self):

    xm.mark_step()
    device = xm.xla_device()
    torch.set_grad_enabled(False)

    class SimpleWithLinear(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.register_buffer("dec", torch.tensor(1))

      def forward(self, iter, x):

        def cond_fn(it, x):
          return it - self.dec > 0

        def body_fn(it, x):
          return it - 1, self.linear(x)

        return while_loop(cond_fn, body_fn, (iter, x))

    simple_with_linear = SimpleWithLinear()
    simple_with_linear.to(device)
    input = torch.randn(2, 2).to(device)
    iter = torch.tensor(3, device=device)
    res = simple_with_linear(iter, input)

This works on the CPU device but errors out on XLA devices. Let's try to get this simple example working.
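
For reference, a minimal standalone sketch of the CPU baseline described above (the module and loop mirror the snippet; running it assumes a PyTorch build that provides torch._higher_order_ops.while_loop):

# Standalone CPU sketch of the snippet above: same module, plain CPU tensors.
import torch
from torch._higher_order_ops.while_loop import while_loop

class SimpleWithLinear(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(2, 2)
    self.register_buffer("dec", torch.tensor(1))

  def forward(self, iter, x):

    def cond_fn(it, x):
      return it - self.dec > 0

    def body_fn(it, x):
      return it - 1, self.linear(x)

    return while_loop(cond_fn, body_fn, (iter, x))

torch.set_grad_enabled(False)
m = SimpleWithLinear()
res = m(torch.tensor(3), torch.randn(2, 2))
print(res)  # loop counter ends at 1; linear is applied twice to the input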

@ManfeiBai ManfeiBai (Collaborator, Author) commented May 9, 2024:

Script used to compare the order of upper/lower in the post-order traversal.
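
The script itself is not included in this thread. Below is a minimal sketch, assuming the point is to dump the lazy IR of the loop-bound tensors so the placement of upper vs. lower in the trace can be compared by eye; _get_xla_tensors_text is torch_xla's existing IR debug binding, and the tensor values are illustrative.

# Minimal sketch: dump the lazy IR for the loop bounds so the ordering of
# `upper` vs. `lower` in the trace can be compared by hand.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
lower = torch.tensor(0, device=device)
upper = torch.tensor(10, device=device)
one = torch.tensor(1, device=device)

# Build a small graph that touches both bounds, e.g. the loop-carried update.
new_lower = lower + one
still_running = upper - new_lower > 0

# Dump the IR text for the resulting tensors and inspect the node order.
print(torch_xla._XLAC._get_xla_tensors_text([new_lower, still_running]))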

@ManfeiBai ManfeiBai (Collaborator, Author) commented:

Thanks. Based on this PR, I created and updated a simpler test case in #7094, so I will close this PR for now.

@ManfeiBai ManfeiBai closed this May 30, 2024