
Commit

Add fix for stack depth when using set custom op_name in a python context
mrnikwaws committed Nov 21, 2023
1 parent 6b72c56 commit 20532a9
Showing 8 changed files with 62 additions and 12 deletions.
8 changes: 5 additions & 3 deletions test/custom_debug_lowering.py
@@ -172,11 +172,13 @@ def CleanNames(names):

 def GetAllObjectAndClassNames(frame):
   names = []
+  frame_count = 0
   while frame is not None:
     name = GetClassNameAndObjFromFrame(frame)
     if len(name) > 0:
       names.append(name)
     frame = frame.f_back
+    frame_count += 1

   names.reverse()

@@ -187,7 +189,7 @@ def GetAllObjectAndClassNames(frame):
   if len(output) > 0:
     output += "/"

-  return output
+  return output, frame_count


 class CustomOpNameLowering(TorchDispatchMode):
@@ -208,6 +210,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs={}):
     res = func(*args, **kwargs)
     if 'xla' in str(res.device):
       frame = inspect.currentframe()
-      prefix = GetAllObjectAndClassNames(frame)
-      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
+      prefix, depth = GetAllObjectAndClassNames(frame)
+      torch_xla._XLAC._set_xla_custom_op_name(res, prefix, depth - 2)
     return res
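
For context, GetAllObjectAndClassNames walks outward through the Python call stack via frame.f_back and now also reports how many frames it visited, so the caller can cap the stack depth recorded for the op (the depth - 2 above trims the helper and dispatch frames). A minimal standalone sketch of that counting logic, with illustrative names only:

import inspect


def count_frames(frame):
  # Walk outward from `frame` and count every frame on the stack.
  depth = 0
  while frame is not None:
    frame = frame.f_back
    depth += 1
  return depth


def current_depth():
  # Subtract the frame added by this helper itself, analogous to the
  # `depth - 2` adjustment in __torch_dispatch__ above.
  return count_frames(inspect.currentframe()) - 1
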
10 changes: 7 additions & 3 deletions test/test_hlo_metadata.py
@@ -44,9 +44,9 @@ def test_metadata(self):
     # Strings to match in the lowering
     bingo = {
         "torch/_ops.py": False,
-        #"torch/nn/modules/linear.py": False,
-        #"torch/nn/modules/activation.py": False,
-        #"torch/nn/functional.py": False,
+        "torch/nn/modules/linear.py": False,
+        "torch/nn/modules/activation.py": False,
+        "torch/nn/functional.py": False,
         "Sequential[model]/Linear[0]": False,
         "Sequential[model]/ReLU[1]": False,
         "Sequential[model]/Linear[2]": False,
@@ -60,6 +60,10 @@
     non_zero_metadata = False

     local_json = json.loads(hlo_text)
+
+    with open("./hlo.json", "w") as f:
+      f.write(json.dumps(local_json, indent=2))
+
     assert "computations" in local_json
     for c in local_json["computations"]:
       if "instructions" in c:
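
Because the test now also writes the lowered module to ./hlo.json, the dump can be inspected offline. A rough sketch of scanning it for the expected metadata strings; the exact field names inside each instruction's metadata object are not assumed, so every string value is checked:

import json

expected = ["Sequential[model]/Linear[0]", "torch/nn/modules/linear.py"]

with open("./hlo.json") as f:
  hlo = json.load(f)

found = set()
for comp in hlo.get("computations", []):
  for inst in comp.get("instructions", []):
    for value in inst.get("metadata", {}).values():
      if isinstance(value, str):
        found.update(n for n in expected if n in value)

print("missing:", [n for n in expected if n not in found])
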
4 changes: 3 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -1919,9 +1919,11 @@ void InitXlaModuleBindings(py::module m) {
           return XLANativeFunctions::set_(self, source);
         });
   m.def("_set_xla_custom_op_name",
-        [](const at::Tensor& input, const std::string& op_name) {
+        [](const at::Tensor& input, const std::string& op_name,
+           size_t max_call_stack_depth) {
           XLATensorPtr xtensor = bridge::GetXlaTensor(input);
           xtensor->SetCustomOpName(op_name);
+          xtensor->SetCustomCallStackDepth(max_call_stack_depth);
         });
   m.def("_get_xla_custom_op_name",
         [](const at::Tensor& input) -> const std::string& {
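
With the extra parameter, the binding is now called as _set_xla_custom_op_name(tensor, op_name, max_call_stack_depth). A hypothetical direct call is sketched below; in practice the CustomOpNameLowering dispatch mode above makes this call for you, and the snippet assumes torch_xla with this change and an available XLA device:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device()) * 3  # leaves a pending IR node

# New signature: (tensor, op_name, max_call_stack_depth); name and depth are illustrative.
torch_xla._XLAC._set_xla_custom_op_name(t, "MyScope/MyOp", 8)
print(torch_xla._XLAC._get_xla_custom_op_name(t))
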
12 changes: 8 additions & 4 deletions torch_xla/csrc/ir.cpp
@@ -49,15 +49,17 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {}

 XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  std::vector<torch::lazy::Shape>&& shapes,
                  const std::function<xla::Shape()>& xla_shape_fn,
                  size_t num_outputs, torch::lazy::hash_t hash_seed)
     : torch::lazy::Node(op, operands, std::move(shapes), num_outputs),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {
   xla_shape_ = GetOpShape(xla_shape_fn);
 }

@@ -68,7 +70,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                        num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(torch::lazy::HashCombine(op.hash(), hash_seed)),
-      dag_hash_(GetOperandHashes(operands, node_hash_)) {}
+      dag_hash_(GetOperandHashes(operands, node_hash_)),
+      max_call_stack_depth_(0) {}

 XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::OpList operands,
                  xla::Shape xla_shape, size_t num_outputs,
@@ -102,7 +105,8 @@ XlaNode::XlaNode(torch::lazy::OpKind op, torch::lazy::Shape shape,
     : torch::lazy::Node(op, shape, num_outputs),
       xla_shape_(std::move(xla_shape)),
       node_hash_(GetOpHash(op, xla_shape_, hash_seed)),
-      dag_hash_(node_hash_) {}
+      dag_hash_(node_hash_),
+      max_call_stack_depth_(0) {}

 XlaNode::XlaNode(torch::lazy::OpKind op, xla::Shape xla_shape,
                  size_t num_outputs, torch::lazy::hash_t hash_seed)
6 changes: 6 additions & 0 deletions torch_xla/csrc/ir.h
@@ -149,6 +149,11 @@ class XlaNode : public torch::lazy::Node {
   void SetCustomOpName(const std::string& op_name);
   const std::string& custom_op_name() const { return custom_op_name_; }

+  void SetCustomCallStackDepth(size_t max_call_stack_depth) {
+    max_call_stack_depth_ = max_call_stack_depth;
+  }
+  const size_t max_call_stack_depth() const { return max_call_stack_depth_; }
+
  protected:
   std::unordered_set<uint32_t> unbounded_dynamic_dims_;

@@ -172,6 +177,7 @@ class XlaNode : public torch::lazy::Node {
   std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;

   std::string custom_op_name_;
+  size_t max_call_stack_depth_;
 };

 inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
12 changes: 11 additions & 1 deletion torch_xla/csrc/lowering_context.cpp
@@ -2,6 +2,7 @@

 #include <torch/csrc/lazy/core/ir_metadata.h>

+#include <iostream>
 #include <sstream>
 #include <stdexcept>
 #include <string_view>
@@ -62,8 +63,14 @@ class HloMetadataSetter {

     const XlaNode* xla_node_cast = dynamic_cast<const XlaNode*>(node);

+    size_t max_stack_depth = nmeta.frame_info.size();
+
     if (xla_node_cast != nullptr && !xla_node_cast->custom_op_name().empty()) {
       op_name_prefix = xla_node_cast->custom_op_name();
+
+      if (xla_node_cast->max_call_stack_depth() != 0) {
+        max_stack_depth = xla_node_cast->max_call_stack_depth();
+      }
     }

     if (!nmeta.scope.empty()) {
@@ -75,9 +82,12 @@
     if (!nmeta.frame_info.empty()) {
       auto frame_it = nmeta.frame_info.rbegin();
       int parent_frame_id = kInvalidIndex;
-      for (; frame_it != nmeta.frame_info.rend(); ++frame_it) {
+      int depth = 0;
+      for (; frame_it != nmeta.frame_info.rend() && depth <= max_stack_depth;
+           ++frame_it) {
         parent_frame_id =
             loctx->AddStackFrameLocation(*frame_it, parent_frame_id);
+        ++depth;
       }

       // Point to first entry / deepest call / top frame in call stack
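
The bounded loop stops recording frames once the depth counter exceeds max_stack_depth, starting from the outermost frame. A rough Python analogue of that truncation, assuming (as the comment in the loop suggests) that frame_info stores the deepest call first:

def limited_frames(frame_info, max_stack_depth):
  # Mirror the bounded loop above: iterate outermost-first and stop once the
  # depth counter passes max_stack_depth (same bound as the C++ condition).
  recorded = []
  depth = 0
  for frame in reversed(frame_info):  # rbegin(): outermost frame first
    if depth > max_stack_depth:
      break
    recorded.append(frame)
    depth += 1
  return recorded
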
16 changes: 16 additions & 0 deletions torch_xla/csrc/tensor.cpp
@@ -912,4 +912,20 @@ const std::string& XLATensor::GetCustomOpName() const {
   }
 }

+void XLATensor::SetCustomCallStackDepth(size_t max_call_stack_depth) {
+  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    xla_node->SetCustomCallStackDepth(max_call_stack_depth);
+  }
+}
+
+size_t XLATensor::GetCustomCallStackDepth() const {
+  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    return xla_node->max_call_stack_depth();
+  } else {
+    return size_t(0);
+  }
+}
+
 }  // namespace torch_xla
6 changes: 6 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -286,6 +286,12 @@ class XLATensor : public torch::lazy::LazyTensor {
   void SetCustomOpName(const std::string& op_name);
   const std::string& GetCustomOpName() const;

+  // When using TorchDispatch (e.g. to set a custom op name) we end up adding
+  // extra frames to the recorded stack frame debug info; this caps the stack
+  // depth that gets recorded.
+  void SetCustomCallStackDepth(size_t max_call_stack_depth);
+  size_t GetCustomCallStackDepth() const;
+
  private:
   XLATensor(const at::Tensor& tensor, const torch::lazy::BackendDevice& device);
   XLATensor(torch::lazy::BackendDataPtr handle,
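
The new comment in tensor.h is the rationale for the whole change: entering a TorchDispatchMode hook adds interpreter frames beneath the user code, so recorded call stacks grow unless capped. A small standalone probe (plain torch only; class and helper names are illustrative) that makes the extra depth visible:

import inspect

import torch
from torch.utils._python_dispatch import TorchDispatchMode


def stack_depth():
  depth, frame = 0, inspect.currentframe()
  while frame is not None:
    frame, depth = frame.f_back, depth + 1
  return depth


class DepthProbe(TorchDispatchMode):

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    # The depth printed here exceeds the depth seen in user code because the
    # dispatch machinery sits between the two.
    print("inside dispatch:", stack_depth())
    return func(*args, **(kwargs or {}))


print("user code:", stack_depth())
with DepthProbe():
  torch.ones(2) + 1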
