From 51a15d916f59d212210c569dcf5ca050e85a44fd Mon Sep 17 00:00:00 2001 From: Nicholas Waldron Date: Wed, 15 Nov 2023 22:23:14 +0000 Subject: [PATCH 1/5] Add python binding to allow custom op_name metadata for lowere HLO --- test/custom_debug_lowering.py | 205 ++++++++++++++++++++++++ test/test_hlo_metadata.py | 90 +++++++++++ torch_xla/csrc/init_python_bindings.cpp | 18 +++ torch_xla/csrc/ir.cpp | 4 + torch_xla/csrc/ir.h | 5 + torch_xla/csrc/lowering_context.cpp | 18 ++- torch_xla/csrc/tensor.cpp | 13 ++ torch_xla/csrc/tensor.h | 4 + 8 files changed, 349 insertions(+), 8 deletions(-) create mode 100644 test/custom_debug_lowering.py create mode 100644 test/test_hlo_metadata.py diff --git a/test/custom_debug_lowering.py b/test/custom_debug_lowering.py new file mode 100644 index 000000000000..6af2b89537a8 --- /dev/null +++ b/test/custom_debug_lowering.py @@ -0,0 +1,205 @@ +import torch +import torch_xla + +import inspect +from collections import defaultdict + +from torch.utils._python_dispatch import TorchDispatchMode + +class_count = defaultdict(int) +instance_count = dict() + +def GetInstancePlaceHolder(class_type, obj): + global class_count + global instance_count + + if (class_type,id(obj)) not in instance_count: + class_count[class_type] += 1 + instance_count[(class_type,id(obj))] = class_count[class_type] + + place_holder = instance_count[(class_type,id(obj))] + + return f".{place_holder}" + +def CheckIgnored(key): + ignored_list = ( + "self", + "_bootstrap", "_fix_up_module", "_get_supported_file_loaders", "_setup", + "_buffers", "_parameters", "_non_persistent_buffers_set" + ) + + return (key.startswith("__") and key.endswith("__")) or key in ignored_list + +def Prefix(prefix, val): + if len(prefix) > 0: + return f"{prefix}.{val}" + else: + return f"{val}" + +def ReverseSearchBreadthFirst(container,obj,debug=False): + if container is None: + return False + + queue = [] + visited = set() + nested_name = "" + max_depth = 5 + queue.append((0,nested_name,container)) + + while len(queue): + depth, prefix, candidate = queue.pop(0) + + if depth > max_depth or id(candidate) in visited: + continue + + visited.add(id(candidate)) + + if isinstance(candidate, dict): + for k,v in candidate.items(): + if not isinstance(k,str): + if debug: + print(f"Found non string key {k}") + break + if CheckIgnored(k): + continue + nested_name = Prefix(prefix,k) + if v is obj: + if debug: + print(f"Found {nested_name}") + return True, nested_name + elif debug: + print(f"Miss {nested_name}") + if id(v) not in visited and depth < max_depth: + queue.append((depth+1,nested_name,v)) + elif isinstance(candidate,(list,tuple)): + for i,v in enumerate(candidate): + nested_name = Prefix(prefix,i) + if v is obj: + if debug: + print(f"Found {nested_name}") + return True, nested_name + elif debug: + print(f"Miss {nested_name}") + if id(v) not in visited and depth < max_depth: + queue.append((depth+1,nested_name,v)) + elif hasattr(candidate, "__class__"): + # Ignore class wich overrides __getattr__ and + # generates error + if type(candidate).__name__ == "_ClassNamespace": + continue + for att in ("_modules","__dict__"): + if hasattr(candidate, att): + v = getattr(candidate, att) + if id(v) not in visited and depth < max_depth: + queue.append((depth+1,nested_name,v)) + else: + print("No action") + + return False, None + +def FindMemberVariable(frame, obj): + parent_frame = frame.f_back + found = False + variable_name = None + + for lframe in inspect.getouterframes(parent_frame): + if lframe.frame.f_code.co_nlocals <= 0: + 
continue + self_name = lframe.frame.f_code.co_varnames[0] + parent_obj = lframe.frame.f_locals[self_name] + found, variable_name = ReverseSearchBreadthFirst(parent_obj,obj) + if found: + break + + return found, variable_name + +def FindLocalVariable(frame, obj): + found = False + variable_name = None + + for lframe in inspect.getouterframes(frame.f_back): + found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals,obj) + if found: + break + + return found, variable_name + +def GetClassNameAndObjFromFrame( frame ): + class_obj_str = "" + if frame.f_code.co_argcount == 0: + return class_obj_str + + likely_obj_name = frame.f_code.co_varnames[0] + + obj = frame.f_locals[likely_obj_name] + + if not hasattr(obj, "__class__") or likely_obj_name != "self": + return class_obj_str + + name = type(obj).__name__ + variable_name = None + found = False + + found, variable_name = FindMemberVariable(frame, obj) + + if not found: + found, variable_name = FindLocalVariable(frame, obj) + + if not found: + variable_name = GetInstancePlaceHolder(name, obj) + + name = name + "[" + variable_name + "]" + + return name + +def CleanNames( names ): + last_name = "" + output = [] + for name in names: + if name != last_name: + output.append(name) + last_name = name + + # Drop the last scope which is the scope name add op_name lowerings + return output[:-1] + +def GetAllObjectAndClassNames( frame ): + names = [] + while frame is not None: + name = GetClassNameAndObjFromFrame(frame) + if len(name) > 0: + names.append(name) + frame = frame.f_back + + names.reverse() + + names = CleanNames(names) + + output = "/".join(names) + + if len(output) > 0: + output += "/" + + return output + +class CustomOpNameLowering(TorchDispatchMode): + + def __init__(self): + super().__init__() + + def __enter__(self): + self._old_ir_debug = torch_xla._XLAC._get_ir_debug() + torch_xla._XLAC._set_ir_debug(True) + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + torch_xla._XLAC._set_ir_debug(self._old_ir_debug) + super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_dispatch__(self, func, types, args=(), kwargs={}): + res = func(*args, **kwargs) + if 'xla' in str(res.device): + frame = inspect.currentframe() + prefix = GetAllObjectAndClassNames(frame) + torch_xla._XLAC._set_xla_custom_op_name(res, prefix) + return res diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py new file mode 100644 index 000000000000..92022f4f7eb6 --- /dev/null +++ b/test/test_hlo_metadata.py @@ -0,0 +1,90 @@ +import sys + +# Normal imports section starts here. 
+import torch +import torch_xla +import torch_xla.utils.utils as xu +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met +import unittest +import json +from custom_debug_lowering import CustomOpNameLowering + +class TestHloMetaData(unittest.TestCase): + + def setUp(self): + torch.manual_seed(42) + self.pre_test_tensor_type = torch.get_default_dtype() + self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug() + torch.set_default_tensor_type(torch.FloatTensor) + torch_xla._XLAC._set_ir_debug(True) + super(TestHloMetaData, self).setUp() + + def tearDown(self): + super(TestHloMetaData, self).tearDown() + torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug) + + def test_metadata(self): + layer1 = torch.nn.Linear(4, 4) + nl1 = torch.nn.ReLU() + layer2 = torch.nn.Linear(4, 2) + nl2 = torch.nn.Tanh() + model = torch.nn.Sequential(layer1, nl1, layer2, nl2) + + with CustomOpNameLowering(): + model = model.to(device=xm.xla_device()) + inp = torch.rand(4, 4, device=xm.xla_device()) + out = model(inp) + + ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx.build([out]) + hlo_text = ctx.hlo_json() + + # Strings to match in the lowering + bingo = { + "torch/_ops.py": False, + #"torch/nn/modules/linear.py": False, + #"torch/nn/modules/activation.py": False, + #"torch/nn/functional.py": False, + "Sequential[model]/Linear[0]": False, + "Sequential[model]/ReLU[1]": False, + "Sequential[model]/Linear[2]": False, + "Sequential[model]/Tanh[3]": False, + "aten__addmm": False, + "aten__relu": False, + "aten__tanh": False, + "aten__permute": False + } + + non_zero_metadata = False + + local_json = json.loads(hlo_text) + assert "computations" in local_json + for c in local_json["computations"]: + if "instructions" in c: + i = c["instructions"] + for op in i: + if 'metadata' in op: + meta = op["metadata"] + print(meta) + if len(meta) > 0: + non_zero_metadata = True + for km, vm in meta.items(): + for k in bingo.keys(): + if isinstance(vm, str) and k in vm: + bingo[k] = True + + assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?" 
+ + for k, v in bingo.items(): + assert v, f"Keyword {k} was not found as expected in HLO metadata for simple test" + + print("All required metadata symbols matched") + + +if __name__ == '__main__': + test = unittest.main(exit=False) + if xu.getenv_as('METRICS_DEBUG', bool, defval=False): + print(met.metrics_report()) + sys.exit(0 if test.result.wasSuccessful() else 1) + diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 8c45d68f8029..babf431413e2 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -855,6 +855,13 @@ class PyLoweringContext { return result; } + std::string GetHloJsonText() { + const xla::HloModuleProto& proto = computation.proto(); + std::string result; + google::protobuf::util::MessageToJsonString(proto, &result); + return result; + } + private: LoweringContext lowering_ctx; xla::XlaComputation computation; @@ -896,6 +903,7 @@ void BuildLoweringContextSubmodule(py::module* m) { .def("build", &PyLoweringContext::Build) .def("hlo", &PyLoweringContext::GetHlo) .def("hlo_text", &PyLoweringContext::GetHloText) + .def("hlo_json", &PyLoweringContext::GetHloJsonText) .def("parameter_id_tensor_mapping", &PyLoweringContext::GetParameterIdTensorMapping) .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId); @@ -1911,6 +1919,16 @@ void InitXlaModuleBindings(py::module m) { [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& { return XLANativeFunctions::set_(self, source); }); + m.def("_set_xla_custom_op_name", + [](const at::Tensor& input, const std::string& op_name) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->SetCustomOpName(op_name); + }); + m.def("_get_xla_custom_op_name", + [](const at::Tensor& input) -> const std::string& { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + return xtensor->GetCustomOpName(); + }); m.def("_get_all_reduce_token", [](const std::string& device_str) -> const torch::lazy::Value& { auto device = GetDeviceOrCurrent(device_str); diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 82b746ab1813..b7cd2025bd34 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -230,4 +230,8 @@ void XlaNode::UpdateShardingHash() { } } +void XlaNode::SetCustomOpName(const std::string& op_name) { + custom_op_name_ = op_name; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index d0619ef5c987..1e4a0439e235 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -146,6 +146,9 @@ class XlaNode : public torch::lazy::Node { return unbounded_dynamic_dims_; } + void SetCustomOpName(const std::string& op_name); + const std::string& custom_op_name() const { return custom_op_name_; } + protected: std::unordered_set unbounded_dynamic_dims_; @@ -167,6 +170,8 @@ class XlaNode : public torch::lazy::Node { // Experimental sharding annotations attached to the IR node. 
std::vector> output_shardings_; + + std::string custom_op_name_; }; inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) { diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 404fa82ea7b1..622e2ef7dd96 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -38,7 +38,7 @@ class HloMetadataSetter { static bool ShouldPopulateXlaOpMetadata() { static bool op_metadata = runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false); - return op_metadata; + return FLAGS_torch_lazy_ir_debug || op_metadata; } static void PopulateXlaOpMetadata(LoweringContext* loctx, @@ -53,6 +53,13 @@ class HloMetadataSetter { metadata.set_op_type(op_type); const torch::lazy::MetaData& nmeta = node->metadata(); std::string op_name_prefix; + + const XlaNode* xla_node_cast = dynamic_cast(node); + + if (xla_node_cast != nullptr && !xla_node_cast->custom_op_name().empty()) { + op_name_prefix = xla_node_cast->custom_op_name(); + } + if (!nmeta.scope.empty()) { op_name_prefix = absl::StrCat(absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/"); @@ -61,13 +68,8 @@ class HloMetadataSetter { if (!nmeta.frame_info.empty()) { const torch::lazy::SourceLocation& frame = nmeta.frame_info.front(); - std::string::size_type pos = frame.file.find_last_of('/'); - if (pos == std::string::npos) { - pos = 0; - } else { - ++pos; - } - metadata.set_source_file(frame.function + "@" + frame.file.substr(pos)); + + metadata.set_source_file(frame.file); metadata.set_source_line(frame.line); } loctx->builder()->SetOpMetadata(std::move(metadata)); diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index b30cbe7c01ec..fa66aec1e2a3 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -897,4 +897,17 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) { xla_node->MarkDynamicDimension(dim); } +void XLATensor::SetCustomOpName(const std::string& op_name) { + auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); + if (xla_node != nullptr) xla_node->SetCustomOpName(op_name); +} + +const std::string& XLATensor::GetCustomOpName() const { + auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); + if (xla_node != nullptr) + return xla_node->custom_op_name(); + else + return ""; +} + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index f73aed5ce5fc..43849458068b 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -282,6 +282,10 @@ class XLATensor : public torch::lazy::LazyTensor { // Override to enable SPMD. 
void AssignIrValue(torch::lazy::Value ir_value) const final; + // Set custom op name on XlaNode + void SetCustomOpName(const std::string& op_name); + const std::string& GetCustomOpName() const; + private: XLATensor(const at::Tensor& tensor, const torch::lazy::BackendDevice& device); XLATensor(torch::lazy::BackendDataPtr handle, From 7d39252219ea0db0c32d6308f9bbbfdcce9ae614 Mon Sep 17 00:00:00 2001 From: Nicholas Waldron Date: Wed, 15 Nov 2023 22:25:25 +0000 Subject: [PATCH 2/5] Lint fix for clang-format --- torch_xla/csrc/tensor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 43849458068b..83db2e95df61 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -284,7 +284,7 @@ class XLATensor : public torch::lazy::LazyTensor { // Set custom op name on XlaNode void SetCustomOpName(const std::string& op_name); - const std::string& GetCustomOpName() const; + const std::string& GetCustomOpName() const; private: XLATensor(const at::Tensor& tensor, const torch::lazy::BackendDevice& device); From b9eb7dc651690eb4eceebe415a6f77cd6c9075fe Mon Sep 17 00:00:00 2001 From: Nicholas Waldron Date: Wed, 15 Nov 2023 22:29:58 +0000 Subject: [PATCH 3/5] Lint fix for customer op_name test code --- test/custom_debug_lowering.py | 326 +++++++++++++++++----------------- test/test_hlo_metadata.py | 8 +- 2 files changed, 171 insertions(+), 163 deletions(-) diff --git a/test/custom_debug_lowering.py b/test/custom_debug_lowering.py index 6af2b89537a8..cc19dd79b81f 100644 --- a/test/custom_debug_lowering.py +++ b/test/custom_debug_lowering.py @@ -9,197 +9,205 @@ class_count = defaultdict(int) instance_count = dict() + def GetInstancePlaceHolder(class_type, obj): - global class_count - global instance_count + global class_count + global instance_count + + if (class_type, id(obj)) not in instance_count: + class_count[class_type] += 1 + instance_count[(class_type, id(obj))] = class_count[class_type] - if (class_type,id(obj)) not in instance_count: - class_count[class_type] += 1 - instance_count[(class_type,id(obj))] = class_count[class_type] + place_holder = instance_count[(class_type, id(obj))] - place_holder = instance_count[(class_type,id(obj))] + return f".{place_holder}" - return f".{place_holder}" def CheckIgnored(key): - ignored_list = ( - "self", - "_bootstrap", "_fix_up_module", "_get_supported_file_loaders", "_setup", - "_buffers", "_parameters", "_non_persistent_buffers_set" - ) + ignored_list = ("self", "_bootstrap", "_fix_up_module", + "_get_supported_file_loaders", "_setup", "_buffers", + "_parameters", "_non_persistent_buffers_set") + + return (key.startswith("__") and key.endswith("__")) or key in ignored_list - return (key.startswith("__") and key.endswith("__")) or key in ignored_list def Prefix(prefix, val): - if len(prefix) > 0: - return f"{prefix}.{val}" - else: - return f"{val}" - -def ReverseSearchBreadthFirst(container,obj,debug=False): - if container is None: - return False - - queue = [] - visited = set() - nested_name = "" - max_depth = 5 - queue.append((0,nested_name,container)) - - while len(queue): - depth, prefix, candidate = queue.pop(0) - - if depth > max_depth or id(candidate) in visited: - continue - - visited.add(id(candidate)) - - if isinstance(candidate, dict): - for k,v in candidate.items(): - if not isinstance(k,str): - if debug: - print(f"Found non string key {k}") - break - if CheckIgnored(k): - continue - nested_name = Prefix(prefix,k) - if v is obj: - if debug: - print(f"Found 
{nested_name}") - return True, nested_name - elif debug: - print(f"Miss {nested_name}") - if id(v) not in visited and depth < max_depth: - queue.append((depth+1,nested_name,v)) - elif isinstance(candidate,(list,tuple)): - for i,v in enumerate(candidate): - nested_name = Prefix(prefix,i) - if v is obj: - if debug: - print(f"Found {nested_name}") - return True, nested_name - elif debug: - print(f"Miss {nested_name}") - if id(v) not in visited and depth < max_depth: - queue.append((depth+1,nested_name,v)) - elif hasattr(candidate, "__class__"): - # Ignore class wich overrides __getattr__ and - # generates error - if type(candidate).__name__ == "_ClassNamespace": - continue - for att in ("_modules","__dict__"): - if hasattr(candidate, att): - v = getattr(candidate, att) - if id(v) not in visited and depth < max_depth: - queue.append((depth+1,nested_name,v)) - else: - print("No action") - - return False, None + if len(prefix) > 0: + return f"{prefix}.{val}" + else: + return f"{val}" + + +def ReverseSearchBreadthFirst(container, obj, debug=False): + if container is None: + return False + + queue = [] + visited = set() + nested_name = "" + max_depth = 5 + queue.append((0, nested_name, container)) + + while len(queue): + depth, prefix, candidate = queue.pop(0) + + if depth > max_depth or id(candidate) in visited: + continue + + visited.add(id(candidate)) + + if isinstance(candidate, dict): + for k, v in candidate.items(): + if not isinstance(k, str): + if debug: + print(f"Found non string key {k}") + break + if CheckIgnored(k): + continue + nested_name = Prefix(prefix, k) + if v is obj: + if debug: + print(f"Found {nested_name}") + return True, nested_name + elif debug: + print(f"Miss {nested_name}") + if id(v) not in visited and depth < max_depth: + queue.append((depth + 1, nested_name, v)) + elif isinstance(candidate, (list, tuple)): + for i, v in enumerate(candidate): + nested_name = Prefix(prefix, i) + if v is obj: + if debug: + print(f"Found {nested_name}") + return True, nested_name + elif debug: + print(f"Miss {nested_name}") + if id(v) not in visited and depth < max_depth: + queue.append((depth + 1, nested_name, v)) + elif hasattr(candidate, "__class__"): + # Ignore class wich overrides __getattr__ and + # generates error + if type(candidate).__name__ == "_ClassNamespace": + continue + for att in ("_modules", "__dict__"): + if hasattr(candidate, att): + v = getattr(candidate, att) + if id(v) not in visited and depth < max_depth: + queue.append((depth + 1, nested_name, v)) + else: + print("No action") + + return False, None + def FindMemberVariable(frame, obj): - parent_frame = frame.f_back - found = False - variable_name = None + parent_frame = frame.f_back + found = False + variable_name = None + + for lframe in inspect.getouterframes(parent_frame): + if lframe.frame.f_code.co_nlocals <= 0: + continue + self_name = lframe.frame.f_code.co_varnames[0] + parent_obj = lframe.frame.f_locals[self_name] + found, variable_name = ReverseSearchBreadthFirst(parent_obj, obj) + if found: + break - for lframe in inspect.getouterframes(parent_frame): - if lframe.frame.f_code.co_nlocals <= 0: - continue - self_name = lframe.frame.f_code.co_varnames[0] - parent_obj = lframe.frame.f_locals[self_name] - found, variable_name = ReverseSearchBreadthFirst(parent_obj,obj) - if found: - break + return found, variable_name - return found, variable_name def FindLocalVariable(frame, obj): - found = False - variable_name = None + found = False + variable_name = None + + for lframe in 
inspect.getouterframes(frame.f_back): + found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals, obj) + if found: + break + + return found, variable_name + + +def GetClassNameAndObjFromFrame(frame): + class_obj_str = "" + if frame.f_code.co_argcount == 0: + return class_obj_str + + likely_obj_name = frame.f_code.co_varnames[0] + + obj = frame.f_locals[likely_obj_name] + + if not hasattr(obj, "__class__") or likely_obj_name != "self": + return class_obj_str - for lframe in inspect.getouterframes(frame.f_back): - found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals,obj) - if found: - break + name = type(obj).__name__ + variable_name = None + found = False - return found, variable_name + found, variable_name = FindMemberVariable(frame, obj) -def GetClassNameAndObjFromFrame( frame ): - class_obj_str = "" - if frame.f_code.co_argcount == 0: - return class_obj_str - - likely_obj_name = frame.f_code.co_varnames[0] - - obj = frame.f_locals[likely_obj_name] + if not found: + found, variable_name = FindLocalVariable(frame, obj) - if not hasattr(obj, "__class__") or likely_obj_name != "self": - return class_obj_str - - name = type(obj).__name__ - variable_name = None - found = False + if not found: + variable_name = GetInstancePlaceHolder(name, obj) - found, variable_name = FindMemberVariable(frame, obj) + name = name + "[" + variable_name + "]" - if not found: - found, variable_name = FindLocalVariable(frame, obj) + return name - if not found: - variable_name = GetInstancePlaceHolder(name, obj) - name = name + "[" + variable_name + "]" +def CleanNames(names): + last_name = "" + output = [] + for name in names: + if name != last_name: + output.append(name) + last_name = name - return name + # Drop the last scope which is the scope name add op_name lowerings + return output[:-1] -def CleanNames( names ): - last_name = "" - output = [] - for name in names: - if name != last_name: - output.append(name) - last_name = name - # Drop the last scope which is the scope name add op_name lowerings - return output[:-1] +def GetAllObjectAndClassNames(frame): + names = [] + while frame is not None: + name = GetClassNameAndObjFromFrame(frame) + if len(name) > 0: + names.append(name) + frame = frame.f_back -def GetAllObjectAndClassNames( frame ): - names = [] - while frame is not None: - name = GetClassNameAndObjFromFrame(frame) - if len(name) > 0: - names.append(name) - frame = frame.f_back + names.reverse() - names.reverse() + names = CleanNames(names) - names = CleanNames(names) + output = "/".join(names) - output = "/".join(names) + if len(output) > 0: + output += "/" - if len(output) > 0: - output += "/" + return output - return output class CustomOpNameLowering(TorchDispatchMode): - def __init__(self): - super().__init__() - - def __enter__(self): - self._old_ir_debug = torch_xla._XLAC._get_ir_debug() - torch_xla._XLAC._set_ir_debug(True) - return super().__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - torch_xla._XLAC._set_ir_debug(self._old_ir_debug) - super().__exit__(exc_type, exc_val, exc_tb) - - def __torch_dispatch__(self, func, types, args=(), kwargs={}): - res = func(*args, **kwargs) - if 'xla' in str(res.device): - frame = inspect.currentframe() - prefix = GetAllObjectAndClassNames(frame) - torch_xla._XLAC._set_xla_custom_op_name(res, prefix) - return res + def __init__(self): + super().__init__() + + def __enter__(self): + self._old_ir_debug = torch_xla._XLAC._get_ir_debug() + torch_xla._XLAC._set_ir_debug(True) + return super().__enter__() + + def 
__exit__(self, exc_type, exc_val, exc_tb): + torch_xla._XLAC._set_ir_debug(self._old_ir_debug) + super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_dispatch__(self, func, types, args=(), kwargs={}): + res = func(*args, **kwargs) + if 'xla' in str(res.device): + frame = inspect.currentframe() + prefix = GetAllObjectAndClassNames(frame) + torch_xla._XLAC._set_xla_custom_op_name(res, prefix) + return res diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py index 92022f4f7eb6..5f5ac186395d 100644 --- a/test/test_hlo_metadata.py +++ b/test/test_hlo_metadata.py @@ -10,6 +10,7 @@ import json from custom_debug_lowering import CustomOpNameLowering + class TestHloMetaData(unittest.TestCase): def setUp(self): @@ -32,9 +33,9 @@ def test_metadata(self): model = torch.nn.Sequential(layer1, nl1, layer2, nl2) with CustomOpNameLowering(): - model = model.to(device=xm.xla_device()) - inp = torch.rand(4, 4, device=xm.xla_device()) - out = model(inp) + model = model.to(device=xm.xla_device()) + inp = torch.rand(4, 4, device=xm.xla_device()) + out = model(inp) ctx = torch_xla._XLAC.lowering.LoweringContext() ctx.build([out]) @@ -87,4 +88,3 @@ def test_metadata(self): if xu.getenv_as('METRICS_DEBUG', bool, defval=False): print(met.metrics_report()) sys.exit(0 if test.result.wasSuccessful() else 1) - From c5f74b12ef6978ced9882e09cb06064a8e3e64d0 Mon Sep 17 00:00:00 2001 From: Nicholas Waldron Date: Fri, 17 Nov 2023 00:23:51 +0000 Subject: [PATCH 4/5] Requested nit changes --- test/run_tests.sh | 1 + torch_xla/csrc/tensor.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index 31b30fa95eeb..4777681013bb 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -163,6 +163,7 @@ function run_xla_op_tests1 { run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug "$CDIR/test_pt_xla_debug.py" run_test "$CDIR/test_async_closures.py" + run_test "$CDIR/test_hlo_metadata.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/pjrt/test_runtime.py" run_test "$CDIR/pjrt/test_runtime_gpu.py" diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index fa66aec1e2a3..85cd70cc8cfa 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -899,15 +899,18 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) { void XLATensor::SetCustomOpName(const std::string& op_name) { auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); - if (xla_node != nullptr) xla_node->SetCustomOpName(op_name); + if (xla_node != nullptr) { + xla_node->SetCustomOpName(op_name); + } } const std::string& XLATensor::GetCustomOpName() const { auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); - if (xla_node != nullptr) + if (xla_node != nullptr) { return xla_node->custom_op_name(); - else + } else { return ""; + } } } // namespace torch_xla From 457dd85d9bbc7ec90ebbc9acdc7a0f84f4baaf0c Mon Sep 17 00:00:00 2001 From: Nicholas Waldron Date: Fri, 17 Nov 2023 23:25:42 +0000 Subject: [PATCH 5/5] As discussed increase timeout on GPU tests by 20% --- test/cpp/run_tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 16bba61b1954..742443228406 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -105,9 +105,9 @@ fi for name in "${test_names[@]}"; do echo "Running $name cpp test..." 
if [ "$LOGFILE" != "" ]; then - bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"} 2> $LOGFILE + bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} 2> $LOGFILE else - bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"} + bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} fi done