
Commit

Add stack frame id metadata and correct filename and line number for custom op_name (pytorch#5838)

* Add Python binding to allow custom op_name metadata for lowered HLO

* As discussed, increase timeout on GPU tests by 20%

* Add lowering for stack frame index and stack frame id in metadata

* Add fix for stack depth when setting a custom op_name in a Python context

* Changes after adding tests for lowered stack frames and finding several issues

* Add routine to XlaNode to search back through operands and recursively set metadata

* Fix recursion condition so we don't explore nodes with metadata
mrnikwaws authored and ManfeiBai committed Dec 1, 2023
1 parent 8c875aa commit 9c35aee
Showing 12 changed files with 380 additions and 54 deletions.
61 changes: 58 additions & 3 deletions test/custom_debug_lowering.py
@@ -9,6 +9,12 @@
class_count = defaultdict(int)
instance_count = dict()

# This is a sample implementation for reading object
# hierarchies from a source stack using a TorchDispatch
# interceptor. We then set the node op_name in XLA
# via the output tensor and direct XLA to ignore the stack
# frames added (due to TorchDispatch) during lowering.


def GetInstancePlaceHolder(class_type, obj):
global class_count
@@ -172,11 +178,21 @@ def CleanNames(names):

def GetAllObjectAndClassNames(frame):
names = []
frame_count = 0
self_found = False
while frame is not None:
if __file__ == frame.f_code.co_filename:
self_found = True

if not self_found:
frame = frame.f_back
continue

name = GetClassNameAndObjFromFrame(frame)
if len(name) > 0:
names.append(name)
frame = frame.f_back
frame_count += 1

names.reverse()

@@ -187,7 +203,24 @@ def GetAllObjectAndClassNames(frame):
if len(output) > 0:
output += "/"

return output
return output, frame_count - 1


class StackLayerSignature:

def __init__(self, filename, func, line):
self.filename = filename
self.func = func
self.line = line

def __str__(self):
return f"{self.filename}|{self.func}|{self.line}"

def __repr__(self):
return str(self)

def __eq__(self, ref):
return self.filename == ref.filename and self.func == ref.func and self.line == ref.line


class CustomOpNameLowering(TorchDispatchMode):
@@ -198,16 +231,38 @@ def __init__(self):
def __enter__(self):
self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
torch_xla._XLAC._set_ir_debug(True)
self.stack_sigs = []
return super().__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
del self.stack_sigs
super().__exit__(exc_type, exc_val, exc_tb)

def add_stack_sig(self, frame, depth):
stack = []
for s in inspect.getouterframes(frame):
sls = StackLayerSignature(s.filename, s.function, s.lineno)
stack.append(sls)

# Pop frames off the top of the stack until it matches the expected depth
while len(stack) > depth:
stack.pop(0)

assert len(stack) == depth

self.stack_sigs.append(stack)

return stack

def __torch_dispatch__(self, func, types, args=(), kwargs={}):
res = func(*args, **kwargs)
if 'xla' in str(res.device):
frame = inspect.currentframe()
prefix = GetAllObjectAndClassNames(frame)
torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
prefix, depth = GetAllObjectAndClassNames(frame)
self.depth = depth
self.add_stack_sig(frame, self.depth)

assert torch_xla._XLAC._set_xla_custom_op_name_prefix(
res, prefix, self.depth), "Custom op set failed"
return res
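
For orientation, here is a minimal usage sketch of the helper above, mirroring the pattern in test_hlo_metadata.py below; it assumes a working torch_xla build with this change and that custom_debug_lowering.py is importable:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from custom_debug_lowering import CustomOpNameLowering

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

# The dispatch interceptor tags each XLA output tensor with an op_name prefix
# and records the stack depth to trim during lowering.
with CustomOpNameLowering() as c:
    model = model.to(device=xm.xla_device())
    inp = torch.rand(4, 4, device=xm.xla_device())
    out = model(inp)

# Lower the captured graph; the HLO JSON now carries the op_name prefixes
# and stack frame ids for the traced operations.
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
hlo_json = ctx.hlo_json()
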
89 changes: 83 additions & 6 deletions test/test_hlo_metadata.py
@@ -8,7 +8,52 @@
import torch_xla.debug.metrics as met
import unittest
import json
from custom_debug_lowering import CustomOpNameLowering
import inspect
import copy
from custom_debug_lowering import CustomOpNameLowering, StackLayerSignature


class HloStackExtractor:

def __init__(self, hlo_json):
assert 'stackFrameIndex' in hlo_json
assert 'fileLocations' in hlo_json['stackFrameIndex']
assert 'stackFrames' in hlo_json['stackFrameIndex']
assert 'fileNames' in hlo_json['stackFrameIndex']
assert 'functionNames' in hlo_json['stackFrameIndex']

self.file_locations = hlo_json['stackFrameIndex']['fileLocations']
self.stack_frames = hlo_json['stackFrameIndex']['stackFrames']
self.file_names = hlo_json['stackFrameIndex']['fileNames']
self.function_names = hlo_json['stackFrameIndex']['functionNames']

def extract(self, stack_frame_id):
stack_sigs = []

stack_frame = self.stack_frames[stack_frame_id - 1]

while True:
file_location_id = stack_frame['fileLocationId']
file_location = self.file_locations[file_location_id - 1]
file_name_id = file_location['fileNameId']
function_name_id = file_location['functionNameId']
line = file_location['line']
file_name = self.file_names[file_name_id - 1]
function_name = self.function_names[function_name_id - 1]

sig = StackLayerSignature(file_name, function_name, line)
stack_sigs.append(sig)

stack_frame_id = 0
if 'parentFrameId' in stack_frame:
stack_frame_id = stack_frame['parentFrameId']

if stack_frame_id == 0:
break
else:
stack_frame = self.stack_frames[stack_frame_id - 1]

return stack_sigs


class TestHloMetaData(unittest.TestCase):
@@ -32,21 +77,25 @@ def test_metadata(self):
nl2 = torch.nn.Tanh()
model = torch.nn.Sequential(layer1, nl1, layer2, nl2)

with CustomOpNameLowering():
with CustomOpNameLowering() as c:
model = model.to(device=xm.xla_device())
inp = torch.rand(4, 4, device=xm.xla_device())
#inp = torch.rand(4, 4)
#inp = inp.to(device=xm.xla_device())
out = model(inp)

# Get outer frames
stack_sigs = c.stack_sigs

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
hlo_text = ctx.hlo_json()

# Strings to match in the lowering
bingo = {
"torch/_ops.py": False,
#"torch/nn/modules/linear.py": False,
#"torch/nn/modules/activation.py": False,
#"torch/nn/functional.py": False,
"torch/nn/modules/linear.py": False,
"torch/nn/modules/activation.py": False,
"torch/nn/functional.py": False,
"Sequential[model]/Linear[0]": False,
"Sequential[model]/ReLU[1]": False,
"Sequential[model]/Linear[2]": False,
@@ -60,10 +109,17 @@
non_zero_metadata = False

local_json = json.loads(hlo_text)

#with open("./hlo.json", "w") as f:
# f.write(json.dumps(local_json, indent=2))

hloEx = HloStackExtractor(local_json)

assert "computations" in local_json
for c in local_json["computations"]:
if "instructions" in c:
i = c["instructions"]

for op in i:
if 'metadata' in op:
meta = op["metadata"]
@@ -75,6 +131,27 @@
if isinstance(vm, str) and k in vm:
bingo[k] = True

# Decode the stack frame id and check that it matches one of
# the passed-in stacks
stack_frame_match = False
if 'stackFrameId' in meta:
hlo_stack_sig = hloEx.extract(meta['stackFrameId'])

for t_sig in stack_sigs:
if len(hlo_stack_sig) == len(t_sig) and hlo_stack_sig == t_sig:
stack_frame_match = True
break
elif len(hlo_stack_sig) > len(t_sig):
hlo_stack_sig_copy = copy.copy(hlo_stack_sig)
discards = []
while len(hlo_stack_sig_copy) > len(t_sig):
discards.append(hlo_stack_sig_copy.pop(0))
# Print an error message on a partial match
if hlo_stack_sig_copy == t_sig:
print(f"** PARTIAL MATCH: Discarded {discards}")

assert stack_frame_match, f"Stack\n{hlo_stack_sig} does not match any of\n{stack_sigs}"

assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"

for k, v in bingo.items():
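For reference, an invented example of the stackFrameIndex layout that HloStackExtractor walks: ids are 1-based, and a missing (or zero) parentFrameId terminates the chain. All values below are made up for illustration and assume the HloStackExtractor class above is in scope.

hlo_json = {
    "stackFrameIndex": {
        "fileNames": ["my_model.py", "torch/nn/modules/linear.py"],
        "functionNames": ["train_step", "forward"],
        "fileLocations": [
            {"fileNameId": 1, "functionNameId": 1, "line": 42},   # my_model.py:train_step
            {"fileNameId": 2, "functionNameId": 2, "line": 114},  # linear.py:forward
        ],
        "stackFrames": [
            {"fileLocationId": 1},                      # outermost frame, no parent
            {"fileLocationId": 2, "parentFrameId": 1},  # innermost frame
        ],
    },
}

hloEx = HloStackExtractor(hlo_json)
# Walks frame 2 -> frame 1 and returns one StackLayerSignature per frame.
sigs = hloEx.extract(2)
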
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
@@ -296,10 +296,12 @@ ptxla_cc_library(
srcs = [
"ir.cpp",
"lowering_context.cpp",
"stack_frame_index_builder.cpp",
],
hdrs = [
"ir.h",
"lowering_context.h",
"stack_frame_index_builder.h",
],
deps = [
":device",
15 changes: 7 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1967,15 +1967,14 @@ void InitXlaModuleBindings(py::module m) {
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});
m.def("_set_xla_custom_op_name",
[](const at::Tensor& input, const std::string& op_name) {
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
size_t max_call_stack_depth) -> bool {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->SetCustomOpName(op_name);
});
m.def("_get_xla_custom_op_name",
[](const at::Tensor& input) -> const std::string& {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->GetCustomOpName();
std::shared_ptr<torch::lazy::UserMetaData> user_meta =
std::make_shared<CustomOpNameMetaData>(op_name_prefix,
max_call_stack_depth);
return xtensor->SetNodeUserMetadata(user_meta);
});
m.def("_get_all_reduce_token",
[](const std::string& device_str) -> const torch::lazy::Value& {
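A minimal sketch of calling the renamed binding directly from Python; it replaces the removed _set_xla_custom_op_name/_get_xla_custom_op_name pair and returns a bool indicating whether the metadata was attached. The tensor, prefix, and depth values here are illustrative.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Turn on IR debug first, as CustomOpNameLowering does on __enter__; the test
# above suggests metadata is only lowered with it enabled.
torch_xla._XLAC._set_ir_debug(True)

t = torch.rand(4, 4, device=xm.xla_device())
ok = torch_xla._XLAC._set_xla_custom_op_name_prefix(
    t, "Sequential[model]/Linear[0]", 2)  # prefix, max call stack depth
assert ok, "Custom op set failed"
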
12 changes: 10 additions & 2 deletions torch_xla/csrc/ir.cpp
@@ -230,8 +230,16 @@ void XlaNode::UpdateShardingHash() {
}
}

void XlaNode::SetCustomOpName(const std::string& op_name) {
custom_op_name_ = op_name;
std::shared_ptr<torch::lazy::UserMetaData> XlaNode::SetUserMetadataForSubGraph(
std::shared_ptr<torch::lazy::UserMetaData> user_meta) {
for (auto np : operands_) {
XlaNode* xnp = dynamic_cast<XlaNode*>(np.get());
if (xnp != nullptr && xnp->user_metadata() == nullptr) {
xnp->SetUserMetadataForSubGraph(user_meta);
}
}
// Only set if there is no metadata already set
return SetUserMetadata(user_meta);
}

} // namespace torch_xla
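
Conceptually, the new routine attaches the metadata to every node it reaches in the operand subgraph but stops descending at nodes that already carry user metadata, so a previously tagged inner subgraph keeps its own prefix. Below is a Python pseudocode sketch of the same traversal; the attribute names operands, user_metadata, and set_user_metadata are hypothetical stand-ins for the C++ accessors.

def set_user_metadata_for_subgraph(node, user_meta):
    # Recurse into operands that have no metadata yet; already-tagged nodes
    # are left alone so their existing op_name prefix is preserved.
    for operand in node.operands:
        if operand.user_metadata is None:
            set_user_metadata_for_subgraph(operand, user_meta)
    # Finally tag this node itself.
    return node.set_user_metadata(user_meta)
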
16 changes: 12 additions & 4 deletions torch_xla/csrc/ir.h
@@ -5,6 +5,7 @@
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_metadata.h>

#include <functional>
#include <iostream>
@@ -146,8 +147,8 @@ class XlaNode : public torch::lazy::Node {
return unbounded_dynamic_dims_;
}

void SetCustomOpName(const std::string& op_name);
const std::string& custom_op_name() const { return custom_op_name_; }
std::shared_ptr<torch::lazy::UserMetaData> SetUserMetadataForSubGraph(
std::shared_ptr<torch::lazy::UserMetaData> user_meta);

protected:
std::unordered_set<uint32_t> unbounded_dynamic_dims_;
Expand All @@ -170,8 +171,6 @@ class XlaNode : public torch::lazy::Node {

// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;

std::string custom_op_name_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
@@ -195,6 +194,15 @@ T* NodeCast(const torch::lazy::Node* node, torch::lazy::OpKind op) {
return const_cast<T*>(casted);
}

struct CustomOpNameMetaData : public torch::lazy::UserMetaData {
CustomOpNameMetaData(const std::string& input_op_name_prefix,
int input_max_stack_depth)
: op_name_prefix(input_op_name_prefix),
max_stack_depth(input_max_stack_depth) {}
std::string op_name_prefix;
size_t max_stack_depth;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_IR_H_