From 20532a9191dbddd230aed58d2d16459d49fac77f Mon Sep 17 00:00:00 2001
From: Nicholas Waldron
Date: Tue, 21 Nov 2023 21:39:08 +0000
Subject: [PATCH] Fix stack depth when setting a custom op_name from a Python
 context
---
 test/custom_debug_lowering.py           |  8 +++++---
 test/test_hlo_metadata.py               | 10 +++++++---
 torch_xla/csrc/init_python_bindings.cpp |  4 +++-
 torch_xla/csrc/ir.cpp                   | 12 ++++++++----
 torch_xla/csrc/ir.h                     |  6 ++++++
 torch_xla/csrc/lowering_context.cpp     | 12 +++++++++++-
 torch_xla/csrc/tensor.cpp               | 16 ++++++++++++++++
 torch_xla/csrc/tensor.h                 |  6 ++++++
 8 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/test/custom_debug_lowering.py b/test/custom_debug_lowering.py
index cc19dd79b81..6319282decd 100644
--- a/test/custom_debug_lowering.py
+++ b/test/custom_debug_lowering.py
@@ -172,11 +172,13 @@ def CleanNames(names):
 
 def GetAllObjectAndClassNames(frame):
   names = []
+  frame_count = 0
   while frame is not None:
     name = GetClassNameAndObjFromFrame(frame)
     if len(name) > 0:
       names.append(name)
     frame = frame.f_back
+    frame_count += 1
 
   names.reverse()
 
@@ -187,7 +189,7 @@ def GetAllObjectAndClassNames(frame):
   if len(output) > 0:
     output += "/"
 
-  return output
+  return output, frame_count
 
 
 class CustomOpNameLowering(TorchDispatchMode):
@@ -208,6 +210,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs={}):
     res = func(*args, **kwargs)
     if 'xla' in str(res.device):
       frame = inspect.currentframe()
-      prefix = GetAllObjectAndClassNames(frame)
-      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
+      prefix, depth = GetAllObjectAndClassNames(frame)
+      torch_xla._XLAC._set_xla_custom_op_name(res, prefix, depth - 2)
     return res
diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py
index 5f5ac186395..6216c2e4d30 100644
--- a/test/test_hlo_metadata.py
+++ b/test/test_hlo_metadata.py
@@ -44,9 +44,9 @@ def test_metadata(self):
     # Strings to match in the lowering
     bingo = {
         "torch/_ops.py": False,
-        #"torch/nn/modules/linear.py": False,
-        #"torch/nn/modules/activation.py": False,
-        #"torch/nn/functional.py": False,
+        "torch/nn/modules/linear.py": False,
+        "torch/nn/modules/activation.py": False,
+        "torch/nn/functional.py": False,
         "Sequential[model]/Linear[0]": False,
         "Sequential[model]/ReLU[1]": False,
         "Sequential[model]/Linear[2]": False,
@@ -60,6 +60,10 @@ def test_metadata(self):
     non_zero_metadata = False
 
     local_json = json.loads(hlo_text)
+
+    with open("./hlo.json", "w") as f:
+      f.write(json.dumps(local_json, indent=2))
+
     assert "computations" in local_json
     for c in local_json["computations"]:
       if "instructions" in c:
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index fde4ee1c9d7..6bc7d2df483 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1919,9 +1919,11 @@ void InitXlaModuleBindings(py::module m) {
           return XLANativeFunctions::set_(self, source);
         });
   m.def("_set_xla_custom_op_name",
-        [](const at::Tensor& input, const std::string& op_name) {
+        [](const at::Tensor& input, const std::string& op_name,
+           size_t max_call_stack_depth) {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          xtensor->SetCustomOpName(op_name);
+         xtensor->SetCustomCallStackDepth(max_call_stack_depth);
        });
   m.def("_get_xla_custom_op_name",
         [](const at::Tensor& input) -> const std::string& {
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index b7cd2025bd3..a52c0cace35 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -49,7 +49,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {}
 
 XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  std::vector<torch::lazy::Shape>&& shapes,
@@ -57,7 +58,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  size_t num_outputs, torch::lazy::hash_t hash_seed)
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {
   xla_shape_ = GetOpShape(xla_shape_fn);
 }
 
@@ -68,7 +70,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                         num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {}
 
 XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  xla::Shape xla_shape, size_t num_outputs,
@@ -102,7 +105,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
     : torch::lazy::Node(op, shape, num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
-      dag_hash_(node_hash_) {}
+      dag_hash_(node_hash_),
+      max_call_stack_depth_(0) {}
 
 XlaNode::XlaNode(torch::lazy::OpKind op, xla::Shape xla_shape,
                  size_t num_outputs, torch::lazy::hash_t hash_seed)
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 1e4a0439e23..1f36db2283f 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -149,6 +149,11 @@ class XlaNode : public torch::lazy::Node {
   void SetCustomOpName(const std::string& op_name);
   const std::string& custom_op_name() const { return custom_op_name_; }
 
+  void SetCustomCallStackDepth(size_t max_call_stack_depth) {
+    max_call_stack_depth_ = max_call_stack_depth;
+  }
+  size_t max_call_stack_depth() const { return max_call_stack_depth_; }
+
  protected:
   std::unordered_set<uint32_t> unbounded_dynamic_dims_;
 
@@ -172,6 +177,7 @@ class XlaNode : public torch::lazy::Node {
   std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
 
   std::string custom_op_name_;
+  size_t max_call_stack_depth_;
 };
 
 inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 4e9f6f71d96..0c27cc072f4 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -2,6 +2,7 @@
 #include
 
+#include
 #include
 #include
 #include
 
@@ -62,8 +63,14 @@ class HloMetadataSetter {
 
     const XlaNode* xla_node_cast = dynamic_cast<const XlaNode*>(node);
 
+    size_t max_stack_depth = nmeta.frame_info.size();
+
     if (xla_node_cast != nullptr && !xla_node_cast->custom_op_name().empty()) {
       op_name_prefix = xla_node_cast->custom_op_name();
+
+      if (xla_node_cast->max_call_stack_depth() != 0) {
+        max_stack_depth = xla_node_cast->max_call_stack_depth();
+      }
     }
 
     if (!nmeta.scope.empty()) {
@@ -75,9 +82,12 @@ class HloMetadataSetter {
     if (!nmeta.frame_info.empty()) {
       auto frame_it = nmeta.frame_info.rbegin();
       int parent_frame_id = kInvalidIndex;
-      for (; frame_it != nmeta.frame_info.rend(); ++frame_it) {
+      size_t depth = 0;
+      for (; frame_it != nmeta.frame_info.rend() && depth <= max_stack_depth;
+           ++frame_it) {
         parent_frame_id =
             loctx->AddStackFrameLocation(*frame_it, parent_frame_id);
+        ++depth;
       }
 
       // Point to first entry / deepest call / top frame in call stack
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 4a97aad68b7..8265264b2ac 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -912,4 +912,20 @@ const std::string& XLATensor::GetCustomOpName() const {
   }
 }
 
+void XLATensor::SetCustomCallStackDepth(size_t max_call_stack_depth) {
+  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    xla_node->SetCustomCallStackDepth(max_call_stack_depth);
+  }
+}
+
+size_t XLATensor::GetCustomCallStackDepth() const {
+  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    return xla_node->max_call_stack_depth();
+  } else {
+    return size_t(0);
+  }
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 83db2e95df6..e4aac686152 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -286,6 +286,12 @@ class XLATensor : public torch::lazy::LazyTensor {
   void SetCustomOpName(const std::string& op_name);
   const std::string& GetCustomOpName() const;
 
+  // When using TorchDispatch (e.g. to set a custom op name) extra frames get
+  // added to the recorded stack frame debug info; this caps the recorded
+  // call-stack depth.
+  void SetCustomCallStackDepth(size_t max_call_stack_depth);
+  size_t GetCustomCallStackDepth() const;
+
  private:
   XLATensor(const at::Tensor& tensor, const torch::lazy::BackendDevice& device);
   XLATensor(torch::lazy::BackendDataPtr handle,
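
Usage sketch (illustrative, not part of the patch): with this change,
torch_xla._XLAC._set_xla_custom_op_name takes a third argument that caps how
many Python frames are recorded in the HLO stack-frame metadata for a tensor's
IR node. The snippet below drives the updated CustomOpNameLowering mode from
test/custom_debug_lowering.py; the model and input shapes are made up, and it
assumes the patched test module is importable.

  import torch
  import torch_xla.core.xla_model as xm

  from custom_debug_lowering import CustomOpNameLowering

  device = xm.xla_device()
  model = torch.nn.Sequential(
      torch.nn.Linear(4, 8),
      torch.nn.ReLU(),
      torch.nn.Linear(8, 2),
  ).to(device)

  # Inside the mode, every op dispatched on an XLA tensor is tagged with a
  # "Sequential[model]/Linear[0]"-style op_name prefix, and the recorded
  # call stack is capped at the Python depth seen at dispatch time minus
  # the two frames added by the dispatch machinery itself (depth - 2).
  with CustomOpNameLowering():
    out = model(torch.randn(3, 4, device=device))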