Extend HLO metadata to include class hierarchy information #5715

Merged · 5 commits · Nov 20, 2023

Changes from 3 commits
213 changes: 213 additions & 0 deletions test/custom_debug_lowering.py
@@ -0,0 +1,213 @@
import torch
import torch_xla

import inspect
from collections import defaultdict

from torch.utils._python_dispatch import TorchDispatchMode

class_count = defaultdict(int)
instance_count = dict()


def GetInstancePlaceHolder(class_type, obj):
  global class_count
  global instance_count

  if (class_type, id(obj)) not in instance_count:
    class_count[class_type] += 1
    instance_count[(class_type, id(obj))] = class_count[class_type]

  place_holder = instance_count[(class_type, id(obj))]

  return f".{place_holder}"


def CheckIgnored(key):
  ignored_list = ("self", "_bootstrap", "_fix_up_module",
                  "_get_supported_file_loaders", "_setup", "_buffers",
                  "_parameters", "_non_persistent_buffers_set")

  return (key.startswith("__") and key.endswith("__")) or key in ignored_list


def Prefix(prefix, val):
  if len(prefix) > 0:
    return f"{prefix}.{val}"
  else:
    return f"{val}"


def ReverseSearchBreadthFirst(container, obj, debug=False):
  if container is None:
    return False, None

  queue = []
  visited = set()
  nested_name = ""
  max_depth = 5
  queue.append((0, nested_name, container))

  while len(queue):
    depth, prefix, candidate = queue.pop(0)

    if depth > max_depth or id(candidate) in visited:
      continue

    visited.add(id(candidate))

    if isinstance(candidate, dict):
      for k, v in candidate.items():
        if not isinstance(k, str):
          if debug:
            print(f"Found non-string key {k}")
          break
        if CheckIgnored(k):
          continue
        nested_name = Prefix(prefix, k)
        if v is obj:
          if debug:
            print(f"Found {nested_name}")
          return True, nested_name
        elif debug:
          print(f"Miss {nested_name}")
        if id(v) not in visited and depth < max_depth:
          queue.append((depth + 1, nested_name, v))
    elif isinstance(candidate, (list, tuple)):
      for i, v in enumerate(candidate):
        nested_name = Prefix(prefix, i)
        if v is obj:
          if debug:
            print(f"Found {nested_name}")
          return True, nested_name
        elif debug:
          print(f"Miss {nested_name}")
        if id(v) not in visited and depth < max_depth:
          queue.append((depth + 1, nested_name, v))
    elif hasattr(candidate, "__class__"):
      # Ignore classes which override __getattr__ and generate errors.
      if type(candidate).__name__ == "_ClassNamespace":
        continue
      for att in ("_modules", "__dict__"):
        if hasattr(candidate, att):
          v = getattr(candidate, att)
          if id(v) not in visited and depth < max_depth:
            queue.append((depth + 1, nested_name, v))
    else:
      print("No action")

  return False, None
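
# Example of what the reverse search recovers (hypothetical container): the
# walk visits dicts, sequences and object attributes breadth-first and returns
# a dotted path to the target object, e.g.
#   holder = {"encoder": [relu]}
#   ReverseSearchBreadthFirst(holder, relu)  # -> (True, "encoder.0")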


def FindMemberVariable(frame, obj):
  parent_frame = frame.f_back
  found = False
  variable_name = None

  for lframe in inspect.getouterframes(parent_frame):
    if lframe.frame.f_code.co_nlocals <= 0:
      continue
    self_name = lframe.frame.f_code.co_varnames[0]
    parent_obj = lframe.frame.f_locals[self_name]
    found, variable_name = ReverseSearchBreadthFirst(parent_obj, obj)
    if found:
      break

  return found, variable_name


def FindLocalVariable(frame, obj):
  found = False
  variable_name = None

  for lframe in inspect.getouterframes(frame.f_back):
    found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals, obj)
    if found:
      break

  return found, variable_name


def GetClassNameAndObjFromFrame(frame):
  class_obj_str = ""
  if frame.f_code.co_argcount == 0:
    return class_obj_str

  likely_obj_name = frame.f_code.co_varnames[0]

  obj = frame.f_locals[likely_obj_name]

  if not hasattr(obj, "__class__") or likely_obj_name != "self":
    return class_obj_str

  name = type(obj).__name__
  variable_name = None
  found = False

  found, variable_name = FindMemberVariable(frame, obj)

  if not found:
    found, variable_name = FindLocalVariable(frame, obj)

  if not found:
    variable_name = GetInstancePlaceHolder(name, obj)

  name = name + "[" + variable_name + "]"

  return name


def CleanNames(names):
  last_name = ""
  output = []
  for name in names:
    if name != last_name:
      output.append(name)
    last_name = name

  # Drop the last scope, which is the scope added by the op_name lowering itself.
  return output[:-1]


def GetAllObjectAndClassNames(frame):
  names = []
  while frame is not None:
    name = GetClassNameAndObjFromFrame(frame)
    if len(name) > 0:
      names.append(name)
    frame = frame.f_back

  names.reverse()

  names = CleanNames(names)

  output = "/".join(names)

  if len(output) > 0:
    output += "/"

  return output


class CustomOpNameLowering(TorchDispatchMode):

  def __init__(self):
    super().__init__()

  def __enter__(self):
    self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
    torch_xla._XLAC._set_ir_debug(True)
    return super().__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
    super().__exit__(exc_type, exc_val, exc_tb)

  def __torch_dispatch__(self, func, types, args=(), kwargs={}):
    res = func(*args, **kwargs)
    if 'xla' in str(res.device):
      frame = inspect.currentframe()
      prefix = GetAllObjectAndClassNames(frame)
      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
    return res
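
Usage note: a minimal sketch of driving the mode (assuming an XLA device is available); test_hlo_metadata.py below exercises the same flow end to end.

import torch
import torch_xla.core.xla_model as xm
from custom_debug_lowering import CustomOpNameLowering

with CustomOpNameLowering():
  model = torch.nn.Linear(4, 4).to(device=xm.xla_device())
  out = model(torch.rand(4, 4, device=xm.xla_device()))
# XLA tensors produced under the mode now carry a scope string such as
# "Linear[model]/" in their op_name metadata.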
90 changes: 90 additions & 0 deletions test/test_hlo_metadata.py
@@ -0,0 +1,90 @@
import sys
Review comment (Collaborator): please add a line in run_test.sh, otherwise this test won't be run by default.


# Normal imports section starts here.
import torch
import torch_xla
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
import json
from custom_debug_lowering import CustomOpNameLowering


class TestHloMetaData(unittest.TestCase):

  def setUp(self):
    torch.manual_seed(42)
    self.pre_test_tensor_type = torch.get_default_dtype()
    self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
    torch.set_default_tensor_type(torch.FloatTensor)
    torch_xla._XLAC._set_ir_debug(True)
    super(TestHloMetaData, self).setUp()

  def tearDown(self):
    super(TestHloMetaData, self).tearDown()
    # Restore the default dtype and IR-debug state captured in setUp.
    torch.set_default_dtype(self.pre_test_tensor_type)
    torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)

  def test_metadata(self):
    layer1 = torch.nn.Linear(4, 4)
    nl1 = torch.nn.ReLU()
    layer2 = torch.nn.Linear(4, 2)
    nl2 = torch.nn.Tanh()
    model = torch.nn.Sequential(layer1, nl1, layer2, nl2)

    with CustomOpNameLowering():
      model = model.to(device=xm.xla_device())
      inp = torch.rand(4, 4, device=xm.xla_device())
      out = model(inp)

    ctx = torch_xla._XLAC.lowering.LoweringContext()
    ctx.build([out])
    hlo_text = ctx.hlo_json()

    # Strings to match in the lowering
    bingo = {
        "torch/_ops.py": False,
        #"torch/nn/modules/linear.py": False,
        #"torch/nn/modules/activation.py": False,
        #"torch/nn/functional.py": False,
        "Sequential[model]/Linear[0]": False,
        "Sequential[model]/ReLU[1]": False,
        "Sequential[model]/Linear[2]": False,
        "Sequential[model]/Tanh[3]": False,
        "aten__addmm": False,
        "aten__relu": False,
        "aten__tanh": False,
        "aten__permute": False
    }

    non_zero_metadata = False

    local_json = json.loads(hlo_text)
    assert "computations" in local_json
    for c in local_json["computations"]:
      if "instructions" in c:
        i = c["instructions"]
        for op in i:
          if 'metadata' in op:
            meta = op["metadata"]
            print(meta)
            if len(meta) > 0:
              non_zero_metadata = True
            for km, vm in meta.items():
              for k in bingo.keys():
                if isinstance(vm, str) and k in vm:
                  bingo[k] = True

    assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"

    for k, v in bingo.items():
      assert v, f"Keyword {k} was not found as expected in HLO metadata for simple test"

    print("All required metadata symbols matched")


if __name__ == '__main__':
  test = unittest.main(exit=False)
  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
    print(met.metrics_report())
  sys.exit(0 if test.result.wasSuccessful() else 1)
18 changes: 18 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -855,6 +855,13 @@ class PyLoweringContext {
    return result;
  }

  std::string GetHloJsonText() {
    const xla::HloModuleProto& proto = computation.proto();
    std::string result;
    google::protobuf::util::MessageToJsonString(proto, &result);
    return result;
  }

 private:
  LoweringContext lowering_ctx;
  xla::XlaComputation computation;
@@ -896,6 +903,7 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("build", &PyLoweringContext::Build)
.def("hlo", &PyLoweringContext::GetHlo)
.def("hlo_text", &PyLoweringContext::GetHloText)
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
@@ -1911,6 +1919,16 @@ void InitXlaModuleBindings(py::module m) {
        [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
          return XLANativeFunctions::set_(self, source);
        });
  m.def("_set_xla_custom_op_name",
        [](const at::Tensor& input, const std::string& op_name) {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          xtensor->SetCustomOpName(op_name);
        });
  m.def("_get_xla_custom_op_name",
        [](const at::Tensor& input) -> const std::string& {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          return xtensor->GetCustomOpName();
        });
  m.def("_get_all_reduce_token",
        [](const std::string& device_str) -> const torch::lazy::Value& {
          auto device = GetDeviceOrCurrent(device_str);
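A hedged round-trip sketch for the two new bindings (assumes an XLA device; names as bound above):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())
torch_xla._XLAC._set_xla_custom_op_name(t, "MyScope/MyOp")
assert torch_xla._XLAC._get_xla_custom_op_name(t) == "MyScope/MyOp"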
4 changes: 4 additions & 0 deletions torch_xla/csrc/ir.cpp
@@ -230,4 +230,8 @@ void XlaNode::UpdateShardingHash() {
  }
}

void XlaNode::SetCustomOpName(const std::string& op_name) {
  custom_op_name_ = op_name;
}

} // namespace torch_xla
5 changes: 5 additions & 0 deletions torch_xla/csrc/ir.h
@@ -146,6 +146,9 @@ class XlaNode : public torch::lazy::Node {
    return unbounded_dynamic_dims_;
  }

  void SetCustomOpName(const std::string& op_name);
  const std::string& custom_op_name() const { return custom_op_name_; }

 protected:
  std::unordered_set<uint32_t> unbounded_dynamic_dims_;
@@ -167,6 +170,8 @@

  // Experimental sharding annotations attached to the IR node.
  std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;

  std::string custom_op_name_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {