
Extend HLO metadata to include class hierarchy information #5715

Merged: 5 commits merged into pytorch:master on Nov 20, 2023
Conversation

mrnikwaws (Contributor)

Summary

This change adds extended metadata to lowered HLO, along with JSON export of HLO.

Notes

Sample HLO as JSON

{'opType': 'xla__device_data', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[2]/xla__device_data', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/module.py', 'sourceLine': 1159}
{'opType': 'xla__device_data', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[2]/xla__device_data', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/module.py', 'sourceLine': 1159}
{'opType': 'aten__permute', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[2]/aten__permute', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/linear.py', 'sourceLine': 114}
{'opType': 'xla__device_data', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[0]/xla__device_data', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/module.py', 'sourceLine': 1159}
{'opType': 'xla__device_data', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[0]/xla__device_data', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/module.py', 'sourceLine': 1159}
{'opType': 'aten__permute', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[0]/aten__permute', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/linear.py', 'sourceLine': 114}
{'opType': 'xla__device_data', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/xla__device_data', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'prim__Constant', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/prim__Constant', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'aten__mul', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/aten__mul', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'prim__Constant', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/prim__Constant', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'aten__add', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/aten__add', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'prim__Constant', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/prim__Constant', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'prim__Constant', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/prim__Constant', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
{'opType': 'aten__uniform', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/aten__uniform', 'sourceFile': '/ansible/torch_xla/pytorch/xla/test/test_hlo_metadata.py', 'sourceLine': 40}
(the aten__uniform entry above appears 20 times in the full output)
{'opType': 'aten__addmm', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[0]/aten__addmm', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/linear.py', 'sourceLine': 114}
(the aten__addmm entry above appears 6 times in the full output)
{'opType': 'aten__relu', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/ReLU[1]/aten__relu', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/functional.py', 'sourceLine': 1476}
(the aten__relu entry above appears 3 times in the full output)
{'opType': 'aten__addmm', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Linear[2]/aten__addmm', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/linear.py', 'sourceLine': 114}
(the aten__addmm entry above appears 6 times in the full output)
{'opType': 'aten__tanh', 'opName': 'TestProgram[.1]/TextTestRunner[testRunner]/TestSuite[test]/TestHloMetaData[_tests.0]/Sequential[model]/Tanh[3]/aten__tanh', 'sourceFile': '/ansible/torch_xla/venv/lib/python3.9/site-packages/torch/nn/modules/activation.py', 'sourceLine': 356}
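
For orientation, here is a minimal sketch of how a dump like the sample above might be produced. The model mirrors the Sequential/Linear/ReLU/Linear/Tanh structure visible in the opName paths; the JSON-export binding name and its return shape are illustrative assumptions, not necessarily the exact API added by this PR.

import json

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2),
                      nn.Tanh()).to(device)
out = model(torch.randn(1, 4, device=device))

# Hypothetical JSON-export binding; the name and return shape are assumptions.
hlo_json = json.loads(torch_xla._XLAC._get_xla_tensors_hlo_json([out]))
for instr in hlo_json:
    print({k: instr.get(k) for k in ('opType', 'opName', 'sourceFile', 'sourceLine')})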

Limitations

  • The algorithm uses Python stack frames to discover class and variable names (see the sketch after this list)
  • If variable names are not present in the stack (e.g. in pre/post-execution hooks, where the stack has already unwound), variable names will not be resolved
  • Memory consumption and search time can be further optimized in future revisions
  • Explicit code has been added to ignore packages that are 'analysis' code
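
A rough Python-level sketch of the frame-walking idea follows (illustrative only: the PR implements this in C++ via the Python C API, with additional filtering of 'analysis' packages):

import inspect

def module_hierarchy_from_stack():
  """Build a path like 'Sequential[model]/Linear[0]' from live stack frames."""
  parts = []
  for frame_info in inspect.stack():
    frame = frame_info.frame
    self_obj = frame.f_locals.get('self')
    if self_obj is None:
      continue
    class_name = type(self_obj).__name__
    # Try to recover the variable name the caller used for this object.
    var_name = None
    caller = frame.f_back
    if caller is not None:
      for name, value in caller.f_locals.items():
        if value is self_obj:
          var_name = name
          break
    parts.append(f'{class_name}[{var_name}]' if var_name else class_name)
  return '/'.join(reversed(parts))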

Testing

  • A basic test has been added; manual testing has been performed on a variety of models

JackCaoG requested review from qihqi and lsy323 on October 20, 2023
qihqi (Collaborator) left a comment

Hi,

I would like to learn more about the intended use case of this feature:

  • What/who would be the consumers of this new information?
  • How is it consumed? By humans, or by a downstream system that expects a very specific format?

(Resolved review threads on test/test_hlo_metadata.py and torch_xla/csrc/lowering_context.cpp)
mrnikwaws (Contributor, Author) commented on Nov 1, 2023

I've responded to specific code feedback above.

In response to the top-level question: this (debug) metadata is consumed first by the compiler, which maintains "backwards pointers" to the metadata through the lowering process. Ultimately it gets embedded in an executable archive, allowing a profiler to perform reverse lookups and then lay out approximately where in the model we are at a given point in the execution (note that due to compilation there will be points in time where multiple code structures are running at once).

This can be particularly useful for optimization, where a user can relate a bottleneck at the machine-instruction level back to a class, a variable name, and ultimately full call stacks (discussed in OpenXLA, but we still need to upstream an implementation, since it requires a change to the protobuf layout).

Presumably other tools (static analysis and the like) could also parse the HLO graph and look back at the source code to perform other tasks.

qihqi (Collaborator) commented on Nov 8, 2023

Hi @mrnikwaws,

My proposal:
There are two separable parts of logic here: one attaches a string to the HLO (currently set on the opName field), and the other computes the stack trace using Python's C API.

We can reduce the scope of the PR to just the first:

i.e. say we add a Python function torch_xla.debugging.attach_custom_op_name(xla_tensor, arbitrary_string). Then the logic to compute the Python stack trace can remain out of tree. You can continue to use your C++ implementation (say, after exposing it to Python with pybind11 or the like), or you can use the stdlib traceback library to generate the Python stack trace.

The attach_custom_op_name API would be intended for SDK builders who want to provide additional functionality on top of Torch/XLA.

Suppose you want that API to be called after every ATen op in a user's PyTorch program, without the user calling it directly. Then you can use a tensor subclass (see examples at https://github.com/albanD/subclass_zoo) wrapping XLATensor and call attach_custom_op_name for the users.

It can look like the following:

class AttachStackTraceToHLO(TorchDispatchMode):

   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
      kwargs = kwargs or {}
      res = func(*args, **kwargs)
      # Only annotate tensors that were lowered to XLA.
      if isinstance(res, torch.Tensor) and res.device.type == 'xla':
         stack_trace = ...  # compute the stack trace here
         torch_xla.debugging.attach_custom_op_name(res, stack_trace)
      return res

And the end user can do:

with AttachStackTraceToHLO():
    run_model()
    # etc.

This way both the custom torch dispatch mode and the code to compute the stack trace can remain out of tree.

===

My rationale:
There can be several backend compilers; other compilers might need different information attached to the HLO, or the same information in a slightly different format.

So either 1) something is part of the standard output of torch_xla's produced HLO, in which case the different backend compilers have to agree on its format and on backward/forward-compatibility (BC/FC) guarantees; or 2) something is custom info that will only be consumed by a particular compiler.

In this case it seems to me that this falls into category 2), and therefore it should not be emitted by default. So instead I propose a way that enables you to do what you want, while also enabling other compilers to attach similar but different information for their backends.

mrnikwaws (Contributor, Author)

This looks like it could work. We already use a Python context for profiling, though it will make some use cases (where folks just use environment variables to lower debug HLO) a little more complex. However, I think I can manage that.

I wasn't aware that we could plug into the dispatcher like this (thanks!).

I'm also unclear on how to connect information from a returned tensor (or tensors in some container) to the related XlaNode in XLA's in-memory graph.

Assuming you can describe such a mechanism: I am using "UserMetaData" on each torch::lazy::Node (https://github.com/pytorch/pytorch/blob/main/torch/csrc/lazy/core/ir_metadata.h#L24-L28), which is freeform metadata on each Node, and then lowering that metadata once the node is realized and lowered. This is needed since torch::lazy Nodes, rather than XlaNodes, are passed to the lowering context.

If we keep the same mechanism, my existing ExtendedFrameInfo struct would minimally become a string. Is this what you were imagining?

If you could help me with an approximate version of how and where to achieve the "custom_op_name" piece (i.e. attaching metadata to an XlaNode from a returned tensor), I'd be OK building it locally and testing.

Some of my other comments still stand (e.g. efficient string encoding in the protobuf), but if we can get to parity with the op_name string for now, that solves the core requirement. These other improvements likely belong in an RFC.

qihqi (Collaborator) commented on Nov 10, 2023

You can actually get the stack trace from Python (without querying the XlaNode):

import torch
import traceback

from torch.utils._python_dispatch import TorchDispatchMode

class Model(torch.nn.Module):

    def forward(self, x, y):
        return x + y


class Disp(TorchDispatchMode):

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    print(func.name())
    print(traceback.extract_stack())
    return func(*args, **kwargs)


with Disp():
    x = torch.randn(100)
    y = torch.randn(100)
    z = Model()(x, y)

This gives:

aten::randn
[<FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 21 in <module>>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 16 in __torch_dispatch__>]
aten::randn
[<FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 22 in <module>>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 16 in __torch_dispatch__>]
aten::add.Tensor
[<FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 23 in <module>>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/torch/nn/modules/module.py, line 1510 in _wrapped_call_impl>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/torch/nn/modules/module.py, line 1519 in _call_impl>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 9 in forward>, <FrameSummary file /mnt/hanq/git/qihqi/pytorch/disp_stack.py, line 16 in __torch_dispatch__>]

So basically we have the location where a particular ATen op is called. You probably want to drop the last frame, since it points inside __torch_dispatch__, and/or the frames inside torch/nn*, since users might not care about those.
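
A minimal sketch of that filtering (illustrative only; the path and function-name checks below are assumptions about what counts as internal):

import traceback

def user_frames():
    # Drop the frame for this helper itself, then filter out frames that
    # point inside torch.nn internals or __torch_dispatch__.
    stack = traceback.extract_stack()[:-1]
    return [
        frame for frame in stack
        if '/torch/nn/' not in frame.filename and frame.name != '__torch_dispatch__'
    ]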

Does this help?

mrnikwaws (Contributor, Author)

As discussed offline, I'm working on the refactor now.

mrnikwaws (Contributor, Author)

Refactor completed; will correct lint errors shortly.

mrnikwaws (Contributor, Author)

Fixed merge issues.

mrnikwaws reopened this on Nov 15, 2023
qihqi (Collaborator) left a comment

A few nits; overall pretty solid. Thanks!

@@ -0,0 +1,90 @@
import sys

qihqi (Collaborator): Please add a line in run_test.sh; otherwise this test won't be run by default.

@@ -897,4 +897,17 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) {
  xla_node->MarkDynamicDimension(dim);
}

void XLATensor::SetCustomOpName(const std::string& op_name) {
  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
  if (xla_node != nullptr) xla_node->SetCustomOpName(op_name);

qihqi (Collaborator) suggested:

if (xla_node != nullptr) {
  xla_node->SetCustomOpName(op_name);
}


const std::string& XLATensor::GetCustomOpName() const {
  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
  if (xla_node != nullptr)

qihqi (Collaborator): Let's wrap all the branches with {}.

qihqi merged commit 5e50379 into pytorch:master on Nov 20, 2023 (17 checks passed)
Commits referencing this pull request (all titled "Add python binding to allow custom op_name metadata for lowered HLO") were subsequently pushed by:

  • zpcore, Nov 21, 2023
  • lsy323 (to lsy323/xla), Nov 28, 2023
  • mrnikwaws (to jeffhataws/xla), Dec 5, 2023
  • jeffhataws (to jeffhataws/xla), Dec 10, 2023
  • chunnienc (to chunnienc/xla), Dec 14, 2023
  • golechwierowicz, Jan 12, 2024
  • bhavya01, Apr 22, 2024

vanbasten23 mentioned this pull request on Mar 5, 2024.