diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh
index 16bba61b1954..742443228406 100755
--- a/test/cpp/run_tests.sh
+++ b/test/cpp/run_tests.sh
@@ -105,9 +105,9 @@ fi
 for name in "${test_names[@]}"; do
   echo "Running $name cpp test..."
   if [ "$LOGFILE" != "" ]; then
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"} 2> $LOGFILE
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} 2> $LOGFILE
   else
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"}
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"}
   fi
 done
 
diff --git a/test/custom_debug_lowering.py b/test/custom_debug_lowering.py
new file mode 100644
index 000000000000..cc19dd79b81f
--- /dev/null
+++ b/test/custom_debug_lowering.py
@@ -0,0 +1,213 @@
+import torch
+import torch_xla
+
+import inspect
+from collections import defaultdict
+
+from torch.utils._python_dispatch import TorchDispatchMode
+
+class_count = defaultdict(int)
+instance_count = dict()
+
+
+def GetInstancePlaceHolder(class_type, obj):
+  global class_count
+  global instance_count
+
+  if (class_type, id(obj)) not in instance_count:
+    class_count[class_type] += 1
+    instance_count[(class_type, id(obj))] = class_count[class_type]
+
+  place_holder = instance_count[(class_type, id(obj))]
+
+  return f".{place_holder}"
+
+
+def CheckIgnored(key):
+  ignored_list = ("self", "_bootstrap", "_fix_up_module",
+                  "_get_supported_file_loaders", "_setup", "_buffers",
+                  "_parameters", "_non_persistent_buffers_set")
+
+  return (key.startswith("__") and key.endswith("__")) or key in ignored_list
+
+
+def Prefix(prefix, val):
+  if len(prefix) > 0:
+    return f"{prefix}.{val}"
+  else:
+    return f"{val}"
+
+
+def ReverseSearchBreadthFirst(container, obj, debug=False):
+  if container is None:
+    return False, None
+
+  queue = []
+  visited = set()
+  nested_name = ""
+  max_depth = 5
+  queue.append((0, nested_name, container))
+
+  while len(queue):
+    depth, prefix, candidate = queue.pop(0)
+
+    if depth > max_depth or id(candidate) in visited:
+      continue
+
+    visited.add(id(candidate))
+
+    if isinstance(candidate, dict):
+      for k, v in candidate.items():
+        if not isinstance(k, str):
+          if debug:
+            print(f"Found non string key {k}")
+          break
+        if CheckIgnored(k):
+          continue
+        nested_name = Prefix(prefix, k)
+        if v is obj:
+          if debug:
+            print(f"Found {nested_name}")
+          return True, nested_name
+        elif debug:
+          print(f"Miss {nested_name}")
+        if id(v) not in visited and depth < max_depth:
+          queue.append((depth + 1, nested_name, v))
+    elif isinstance(candidate, (list, tuple)):
+      for i, v in enumerate(candidate):
+        nested_name = Prefix(prefix, i)
+        if v is obj:
+          if debug:
+            print(f"Found {nested_name}")
+          return True, nested_name
+        elif debug:
+          print(f"Miss {nested_name}")
+        if id(v) not in visited and depth < max_depth:
+          queue.append((depth + 1, nested_name, v))
+    elif hasattr(candidate, "__class__"):
+      # Ignore classes which override __getattr__ and
+      # generate errors when probed
+      if type(candidate).__name__ == "_ClassNamespace":
+        continue
+      for att in ("_modules", "__dict__"):
+        if hasattr(candidate, att):
+          v = getattr(candidate, att)
+          if id(v) not in visited and depth < max_depth:
+            queue.append((depth + 1, nested_name, v))
+    else:
+      print("No action")
+
+  return False, None
+
+
+def FindMemberVariable(frame, obj):
+  parent_frame = frame.f_back
+  found = False
+  variable_name = None
+
+  for lframe in inspect.getouterframes(parent_frame):
+    if lframe.frame.f_code.co_nlocals <= 0:
+      continue
+    self_name = lframe.frame.f_code.co_varnames[0]
+    parent_obj = lframe.frame.f_locals[self_name]
+    found, variable_name = ReverseSearchBreadthFirst(parent_obj, obj)
+    if found:
+      break
+
+  return found, variable_name
+
+
+def FindLocalVariable(frame, obj):
+  found = False
+  variable_name = None
+
+  for lframe in inspect.getouterframes(frame.f_back):
+    found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals, obj)
+    if found:
+      break
+
+  return found, variable_name
+
+
+def GetClassNameAndObjFromFrame(frame):
+  class_obj_str = ""
+  if frame.f_code.co_argcount == 0:
+    return class_obj_str
+
+  likely_obj_name = frame.f_code.co_varnames[0]
+
+  obj = frame.f_locals[likely_obj_name]
+
+  if not hasattr(obj, "__class__") or likely_obj_name != "self":
+    return class_obj_str
+
+  name = type(obj).__name__
+  variable_name = None
+  found = False
+
+  found, variable_name = FindMemberVariable(frame, obj)
+
+  if not found:
+    found, variable_name = FindLocalVariable(frame, obj)
+
+  if not found:
+    variable_name = GetInstancePlaceHolder(name, obj)
+
+  name = name + "[" + variable_name + "]"
+
+  return name
+
+
+def CleanNames(names):
+  last_name = ""
+  output = []
+  for name in names:
+    if name != last_name:
+      output.append(name)
+    last_name = name
+
+  # Drop the last scope, which is the frame in which the
+  # op_name lowering itself runs
+  return output[:-1]
+
+
+def GetAllObjectAndClassNames(frame):
+  names = []
+  while frame is not None:
+    name = GetClassNameAndObjFromFrame(frame)
+    if len(name) > 0:
+      names.append(name)
+    frame = frame.f_back
+
+  names.reverse()
+
+  names = CleanNames(names)
+
+  output = "/".join(names)
+
+  if len(output) > 0:
+    output += "/"
+
+  return output
+
+
+class CustomOpNameLowering(TorchDispatchMode):
+
+  def __init__(self):
+    super().__init__()
+
+  def __enter__(self):
+    self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
+    torch_xla._XLAC._set_ir_debug(True)
+    return super().__enter__()
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
+    super().__exit__(exc_type, exc_val, exc_tb)
+
+  def __torch_dispatch__(self, func, types, args=(), kwargs={}):
+    res = func(*args, **kwargs)
+    if 'xla' in str(res.device):
+      frame = inspect.currentframe()
+      prefix = GetAllObjectAndClassNames(frame)
+      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
+    return res
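For orientation (this note is not part of the patch): `CustomOpNameLowering` is a `TorchDispatchMode` used as a context manager around model construction and the forward pass; while it is active, every op that produces an XLA tensor is tagged with a scope string recovered from the Python call stack. A minimal usage sketch, mirroring the new test further down:

```python
import torch
import torch_xla.core.xla_model as xm
from custom_debug_lowering import CustomOpNameLowering

# Ops dispatched inside the block pick up a prefix such as
# "Sequential[model]/Linear[0]/" derived from the enclosing modules.
with CustomOpNameLowering():
  model = torch.nn.Sequential(torch.nn.Linear(4, 4)).to(xm.xla_device())
  out = model(torch.rand(4, 4, device=xm.xla_device()))
```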
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 31b30fa95eeb..4777681013bb 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -163,6 +163,7 @@ function run_xla_op_tests1 {
   run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
   run_test "$CDIR/test_async_closures.py"
+  run_test "$CDIR/test_hlo_metadata.py"
   run_test "$CDIR/test_profiler.py"
   run_test "$CDIR/pjrt/test_runtime.py"
   run_test "$CDIR/pjrt/test_runtime_gpu.py"
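The new test below builds a small model under `CustomOpNameLowering`, dumps the lowered module as JSON via the new `hlo_json()` binding, and checks that the expected scope and op-type strings appear in the instruction metadata. Roughly, the traversal it performs looks like this (a sketch; `ctx` stands for a built `LoweringContext`, and the metadata values in the comment are illustrative, with proto3 JSON's camelCase field names assumed):

```python
import json

hlo = json.loads(ctx.hlo_json())
for comp in hlo.get("computations", []):
  for instr in comp.get("instructions", []):
    # e.g. {"opType": "aten__addmm",
    #       "opName": "Sequential[model]/Linear[0]/aten__addmm",
    #       "sourceFile": "torch/_ops.py", "sourceLine": 502}
    print(instr.get("metadata", {}))
```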
diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py
new file mode 100644
index 000000000000..5f5ac186395d
--- /dev/null
+++ b/test/test_hlo_metadata.py
@@ -0,0 +1,91 @@
+import sys
+
+# Normal imports section starts here.
+import torch
+import torch_xla
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import unittest
+import json
+from custom_debug_lowering import CustomOpNameLowering
+
+
+class TestHloMetaData(unittest.TestCase):
+
+  def setUp(self):
+    torch.manual_seed(42)
+    self.pre_test_tensor_type = torch.get_default_dtype()
+    self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
+    torch.set_default_tensor_type(torch.FloatTensor)
+    torch_xla._XLAC._set_ir_debug(True)
+    super(TestHloMetaData, self).setUp()
+
+  def tearDown(self):
+    super(TestHloMetaData, self).tearDown()
+    torch.set_default_dtype(self.pre_test_tensor_type)
+    torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)
+
+  def test_metadata(self):
+    layer1 = torch.nn.Linear(4, 4)
+    nl1 = torch.nn.ReLU()
+    layer2 = torch.nn.Linear(4, 2)
+    nl2 = torch.nn.Tanh()
+    model = torch.nn.Sequential(layer1, nl1, layer2, nl2)
+
+    with CustomOpNameLowering():
+      model = model.to(device=xm.xla_device())
+      inp = torch.rand(4, 4, device=xm.xla_device())
+      out = model(inp)
+
+    ctx = torch_xla._XLAC.lowering.LoweringContext()
+    ctx.build([out])
+    hlo_text = ctx.hlo_json()
+
+    # Strings to match in the lowering
+    bingo = {
+        "torch/_ops.py": False,
+        #"torch/nn/modules/linear.py": False,
+        #"torch/nn/modules/activation.py": False,
+        #"torch/nn/functional.py": False,
+        "Sequential[model]/Linear[0]": False,
+        "Sequential[model]/ReLU[1]": False,
+        "Sequential[model]/Linear[2]": False,
+        "Sequential[model]/Tanh[3]": False,
+        "aten__addmm": False,
+        "aten__relu": False,
+        "aten__tanh": False,
+        "aten__permute": False
+    }
+
+    non_zero_metadata = False
+
+    local_json = json.loads(hlo_text)
+    assert "computations" in local_json
+    for c in local_json["computations"]:
+      if "instructions" in c:
+        i = c["instructions"]
+        for op in i:
+          if 'metadata' in op:
+            meta = op["metadata"]
+            print(meta)
+            if len(meta) > 0:
+              non_zero_metadata = True
+              for km, vm in meta.items():
+                for k in bingo.keys():
+                  if isinstance(vm, str) and k in vm:
+                    bingo[k] = True
+
+    assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"
+
+    for k, v in bingo.items():
+      assert v, f"Keyword {k} was not found as expected in HLO metadata for simple test"
+
+    print("All required metadata symbols matched")
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
+    print(met.metrics_report())
+  sys.exit(0 if test.result.wasSuccessful() else 1)
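The C++ changes below add the `hlo_json()` lowering-context method used above, plus a pair of tensor-level bindings for attaching a custom op name to a tensor's IR node and reading it back. A minimal round-trip sketch of those two bindings (assumes an XLA device is available; the scope string is arbitrary):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.rand(2, 2, device=xm.xla_device()) + 1  # tensor backed by an IR node
torch_xla._XLAC._set_xla_custom_op_name(t, "MyScope[example]/")
assert torch_xla._XLAC._get_xla_custom_op_name(t) == "MyScope[example]/"
```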
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 4758579bbb6a..fde4ee1c9d70 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -854,6 +854,13 @@ class PyLoweringContext {
     return result;
   }
 
+  std::string GetHloJsonText() {
+    const xla::HloModuleProto& proto = computation.proto();
+    std::string result;
+    google::protobuf::util::MessageToJsonString(proto, &result);
+    return result;
+  }
+
  private:
  LoweringContext lowering_ctx;
  xla::XlaComputation computation;
@@ -895,6 +902,7 @@ void BuildLoweringContextSubmodule(py::module* m) {
       .def("build", &PyLoweringContext::Build)
       .def("hlo", &PyLoweringContext::GetHlo)
       .def("hlo_text", &PyLoweringContext::GetHloText)
+      .def("hlo_json", &PyLoweringContext::GetHloJsonText)
       .def("parameter_id_tensor_mapping",
            &PyLoweringContext::GetParameterIdTensorMapping)
       .def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
@@ -1910,6 +1918,16 @@ void InitXlaModuleBindings(py::module m) {
         [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
           return XLANativeFunctions::set_(self, source);
         });
+  m.def("_set_xla_custom_op_name",
+        [](const at::Tensor& input, const std::string& op_name) {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          xtensor->SetCustomOpName(op_name);
+        });
+  m.def("_get_xla_custom_op_name",
+        [](const at::Tensor& input) -> const std::string& {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          return xtensor->GetCustomOpName();
+        });
   m.def("_get_all_reduce_token",
         [](const std::string& device_str) -> const torch::lazy::Value& {
           auto device = GetDeviceOrCurrent(device_str);
diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 82b746ab1813..b7cd2025bd34 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -230,4 +230,8 @@ void XlaNode::UpdateShardingHash() {
   }
 }
 
+void XlaNode::SetCustomOpName(const std::string& op_name) {
+  custom_op_name_ = op_name;
+}
+
 }  // namespace torch_xla
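With the `lowering_context.cpp` change below, op metadata is emitted when either `XLA_HLO_DEBUG=1` is set or the lazy-IR debug flag is on, and `source_file` now carries the full path rather than `function@basename`. A quick way to eyeball the effect (a sketch; since the env var is read once into a static, it must be set before the first lowering in the process):

```python
import os
os.environ["XLA_HLO_DEBUG"] = "1"  # or rely on _set_ir_debug(True)

import torch
import torch_xla
import torch_xla.core.xla_model as xm

out = torch.nn.Linear(2, 2).to(xm.xla_device())(
    torch.rand(2, 2, device=xm.xla_device()))
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
print(ctx.hlo_text())  # op_name / source_file / source_line appear in the dump
```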
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index d0619ef5c987..1e4a0439e235 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -146,6 +146,9 @@ class XlaNode : public torch::lazy::Node {
     return unbounded_dynamic_dims_;
   }
 
+  void SetCustomOpName(const std::string& op_name);
+  const std::string& custom_op_name() const { return custom_op_name_; }
+
 protected:
  std::unordered_set<uint32_t> unbounded_dynamic_dims_;
 
@@ -167,6 +170,8 @@
 
   // Experimental sharding annotations attached to the IR node.
   std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
+
+  std::string custom_op_name_;
 };
 
 inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 404fa82ea7b1..622e2ef7dd96 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -38,7 +38,7 @@ class HloMetadataSetter {
   static bool ShouldPopulateXlaOpMetadata() {
     static bool op_metadata =
         runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
-    return op_metadata;
+    return FLAGS_torch_lazy_ir_debug || op_metadata;
   }
 
   static void PopulateXlaOpMetadata(LoweringContext* loctx,
@@ -53,6 +53,13 @@ class HloMetadataSetter {
     metadata.set_op_type(op_type);
     const torch::lazy::MetaData& nmeta = node->metadata();
     std::string op_name_prefix;
+
+    const XlaNode* xla_node_cast = dynamic_cast<const XlaNode*>(node);
+
+    if (xla_node_cast != nullptr && !xla_node_cast->custom_op_name().empty()) {
+      op_name_prefix = xla_node_cast->custom_op_name();
+    }
+
     if (!nmeta.scope.empty()) {
       op_name_prefix =
           absl::StrCat(absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/");
@@ -61,13 +68,8 @@ class HloMetadataSetter {
 
     if (!nmeta.frame_info.empty()) {
       const torch::lazy::SourceLocation& frame = nmeta.frame_info.front();
-      std::string::size_type pos = frame.file.find_last_of('/');
-      if (pos == std::string::npos) {
-        pos = 0;
-      } else {
-        ++pos;
-      }
-      metadata.set_source_file(frame.function + "@" + frame.file.substr(pos));
+
+      metadata.set_source_file(frame.file);
       metadata.set_source_line(frame.line);
     }
     loctx->builder()->SetOpMetadata(std::move(metadata));
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 96465abf44c1..4a97aad68b77 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -896,4 +896,20 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) {
   xla_node->MarkDynamicDimension(dim);
 }
 
+void XLATensor::SetCustomOpName(const std::string& op_name) {
+  auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    xla_node->SetCustomOpName(op_name);
+  }
+}
+
+const std::string& XLATensor::GetCustomOpName() const {
+  auto* xla_node = dynamic_cast<const XlaNode*>(CurrentIrValue().node.get());
+  if (xla_node != nullptr) {
+    return xla_node->custom_op_name();
+  }
+  static const std::string kEmpty;  // avoid returning a reference to a temporary
+  return kEmpty;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index f73aed5ce5fc..83db2e95df61 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -282,6 +282,10 @@ class XLATensor : public torch::lazy::LazyTensor {
   // Override to enable SPMD.
   void AssignIrValue(torch::lazy::Value ir_value) const final;
 
+  // Set/get a custom op name on the underlying XlaNode.
+  void SetCustomOpName(const std::string& op_name);
+  const std::string& GetCustomOpName() const;
+
 private:
  XLATensor(const at::Tensor& tensor,
            const torch::lazy::BackendDevice& device);
  XLATensor(torch::lazy::BackendDataPtr handle,