diff --git a/test/test_hlo_metadata.py b/test/test_hlo_metadata.py
index 2ad4292e06e..51a70191b31 100644
--- a/test/test_hlo_metadata.py
+++ b/test/test_hlo_metadata.py
@@ -4,14 +4,6 @@
 import argparse
 import sys
 
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument('--replicated', action='store_true')
-parser.add_argument('--long_test', action='store_true')
-parser.add_argument('--max_diff_count', type=int, default=25)
-parser.add_argument('--verbosity', type=int, default=0)
-FLAGS, leftovers = parser.parse_known_args()
-sys.argv = [sys.argv[0]] + leftovers
-
 # Normal imports section starts here.
 import torch
 import torch_xla
@@ -25,13 +17,18 @@ class TestHloMetaData(unittest.TestCase):
 
   def setUp(self):
+    torch.manual_seed(42)
+    self.pre_test_tensor_type = torch.get_default_dtype()
+    self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
+    torch.set_default_tensor_type(torch.FloatTensor)
+    torch_xla._XLAC._set_ir_debug(True)
     super(TestHloMetaData, self).setUp()
 
   def tearDown(self):
     super(TestHloMetaData, self).tearDown()
+    torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)
 
   def test_metadata(self):
-    torch_xla._XLAC._set_ir_debug(True)
     layer1 = torch.nn.Linear(4, 4)
     nl1 = torch.nn.ReLU()
     layer2 = torch.nn.Linear(4, 2)
@@ -69,7 +66,6 @@ def test_metadata(self):
       for op in i:
         if 'metadata' in op:
           meta = op["metadata"]
-          print(meta)
           if len(meta) > 0:
             non_zero_metadata = True
             for km, vm in meta.items():
@@ -86,11 +82,7 @@ def test_metadata(self):
 
 if __name__ == '__main__':
-  torch.set_default_tensor_type('torch.FloatTensor')
-  torch.manual_seed(42)
-  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
-      use_full_mat_mul_precision=True)
-  test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
+  test = unittest.main(exit=False)
   if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
     print(met.metrics_report())
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index d1ddc64fb76..3521b531961 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -36,6 +36,7 @@
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
 #include "torch_xla/csrc/layout_manager.h"
+#include "torch_xla/csrc/python_util.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/xla_ops.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
@@ -1899,6 +1900,12 @@
   m.def("_set_ir_debug",
         [](bool ir_debug) { FLAGS_torch_lazy_ir_debug = ir_debug; });
   m.def("_get_ir_debug", []() { return FLAGS_torch_lazy_ir_debug; });
+  m.def("_add_xla_internal_module", [](const std::string& module) {
+    InternalModuleRegistry::Instance()->AddInternalModule(module);
+  });
+  m.def("_check_xla_internal_module", [](const std::string& module) {
+    return InternalModuleRegistry::Instance()->CheckModuleString(module);
+  });
   m.def("_set_xla_handle_special_scalars", [](bool handle_special_scalars) {
     FLAGS_torch_lazy_handle_special_scalars = handle_special_scalars;
   });
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index aae0f4f519c..574de221821 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -37,46 +37,33 @@ class HloMetadataSetter {
   }
 
  private:
-  static bool ShouldPopulateXlaOpMetadata() {
-    static bool op_metadata =
-        runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
-    // Allows us to turn on HLO lowering with a python binding
-    // e.g. in a profile context
-    return FLAGS_torch_lazy_ir_debug || op_metadata;
-  }
-
-  static void PopulateXlaOpMetadata(LoweringContext* loctx,
-                                    const torch::lazy::Node* node) {
-    xla::OpMetadata metadata;
-    // NOTE: we apply some string manipulation as xprof backend utility
-    // for nesting/grouping traces depends on certain op name/type
-    // patterns for classification.
-    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
-    std::string op_type =
-        absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
-    metadata.set_op_type(op_type);
-    const torch::lazy::MetaData& nmeta = node->metadata();
+  static std::string FetchDebugOpNamePrefix(
+      LoweringContext* loctx,
+      const torch::lazy::MetaData& nmeta,
+      const torch::lazy::Node* node) {
     ExtendedFrameInfo* efi = nullptr;
+    std::string class_prefix;
     if (FLAGS_torch_lazy_ir_debug) {
-      TF_VLOG(1) << "PopulateXlaOpMetadata: Debug lowering enabled";
-      efi = static_cast<ExtendedFrameInfo*>(node->user_metadata());
+      TF_VLOG(1) << "FetchDebugOpNamePrefix: Debug lowering enabled";
+      efi = dynamic_cast<ExtendedFrameInfo*>(node->user_metadata());
       if (efi != nullptr && efi->frames.size() != nmeta.frame_info.size()) {
-        LOG(WARNING) << "PopulateXlaOpMetadata: Extra frame information length "
-                        "does not match source "
-                     << "location length as expected " << efi->frames.size()
-                     << " != " << nmeta.frame_info.size();
+        LOG(WARNING) << "FetchDebugOpNamePrefix: Extra frame information length "
+                        "does not match source "
+                     << "location length as expected " << efi->frames.size()
+                     << " != " << nmeta.frame_info.size();
+        return class_prefix;
       }
     } else {
-      TF_VLOG(1) << "PopulateXlaOpMetadata: Debug lowering is *not* enabled";
+      TF_VLOG(1) << "FetchDebugOpNamePrefix: Debug lowering is *not* enabled";
+      return class_prefix;
     }
-    // Add class information iff there are frames
-    std::string last_class_name;
-    std::string class_prefix;
+    // Add class information iff there is extended frame information
     std::stringstream class_ss;
+    std::string last_class_name;
     std::string variable_name;
     std::string last_variable_name;
@@ -110,10 +97,32 @@ class HloMetadataSetter {
         last_class_name = it->class_name;
       }
       class_prefix = class_ss.str();
-      TF_VLOG(1) << "PopulateXlaOpMetadata: Class prefix = '" << class_prefix
+      TF_VLOG(1) << "FetchDebugOpNamePrefix: Class prefix = '" << class_prefix
                  << "'";
     }
+    return class_prefix;
+  }
+
+  static bool ShouldPopulateXlaOpMetadata() {
+    static bool op_metadata =
+        runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
+    // Allows us to turn on HLO lowering with a python binding
+    // e.g. in a profile context
+    return FLAGS_torch_lazy_ir_debug || op_metadata;
+  }
+
+  static void PopulateXlaOpMetadata(LoweringContext* loctx,
+                                    const torch::lazy::Node* node) {
+    xla::OpMetadata metadata;
+    const torch::lazy::MetaData& nmeta = node->metadata();
+    // NOTE: we apply some string manipulation as xprof backend utility
+    // for nesting/grouping traces depends on certain op name/type
+    // patterns for classification.
+    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
+    std::string op_type =
+        absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
+    metadata.set_op_type(op_type);
     if (!nmeta.frame_info.empty()) {
       // Print callstack to debug
       int i = 0;
@@ -130,10 +139,12 @@ class HloMetadataSetter {
       metadata.set_source_line(frame.line);
     }
 
-    std::string op_name_prefix = class_prefix;
+    // Set to empty string if there is no debug information attached to the node from lowering
+    std::string op_name_prefix = FetchDebugOpNamePrefix(loctx, nmeta, node);
+
     if (!nmeta.scope.empty())
       op_name_prefix = absl::StrCat(
-          absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/", class_prefix);
+          absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/", op_name_prefix);
     TF_VLOG(1) << "PopulateXlaOpMetadata: Op name prefix = '"
                << op_name_prefix << "'";
diff --git a/torch_xla/csrc/python_util.cpp b/torch_xla/csrc/python_util.cpp
index e06fde46d70..1ae9501d18a 100644
--- a/torch_xla/csrc/python_util.cpp
+++ b/torch_xla/csrc/python_util.cpp
@@ -10,12 +10,33 @@
 #include "absl/types/optional.h"
 #include "torch/csrc/jit/python/pybind.h"
 #include "torch_xla/csrc/ir_metadata.h"
+
+#ifdef VLOG_IS_ON
+#undef VLOG_IS_ON
+#endif
 #include "torch_xla/csrc/runtime/tf_logging.h"
 
 namespace py = pybind11;
 
 namespace torch_xla {
 
+InternalModuleRegistry* InternalModuleRegistry::Instance() {
+  static InternalModuleRegistry s_Instance;
+  return &s_Instance;
+}
+
+void InternalModuleRegistry::AddInternalModule(
+    const std::string& module) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  internal_modules_.insert(module);
+}
+
+bool InternalModuleRegistry::CheckModuleString(
+    const std::string& module) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  return internal_modules_.find(module) != internal_modules_.end();
+}
+
 std::string AddrToString(py::handle& obj) {
   std::stringstream ss;
   ss << std::hex << obj.ptr();
@@ -85,25 +106,21 @@ bool CheckIgnoredKey(const std::string& key) {
            key[len - 2] == '_' && key[len - 1] == '_')));
 }
 
-bool CheckNeuronInternal(PyFrameObject* frame) {
-  bool neuronx_call = false;
-  static std::string neuronx_string("torch_neuronx");
+bool CheckInternalModule(PyFrameObject* frame) {
+  bool internal_call = false;
   if (frame != nullptr || frame->f_globals != nullptr) {
     auto dict = py::reinterpret_borrow<py::dict>(frame->f_globals);
     if (dict.contains("__name__")) {
       std::string module_name = py::cast<std::string>(dict["__name__"]);
-      TF_VLOG(1) << "CheckNeuronInternal: Module name = "
+      TF_VLOG(1) << "CheckInternalModule: Module name = "
                  << py::cast<std::string>(dict["__name__"]);
-      if (module_name.length() >= neuronx_string.length() &&
-          module_name.find(neuronx_string) != std::string::npos) {
-        neuronx_call = true;
-      }
+      internal_call = InternalModuleRegistry::Instance()->CheckModuleString(module_name);
     }
   }
-  return neuronx_call;
+  return internal_call;
 }
 
 bool ReverseSearchBreadthFirst(py::object& container, py::object& obj_to_find,
@@ -365,7 +382,7 @@ void GetClassNameAndObjFromFrame(PyFrameObject* frame,
   if (!found_name && frame->f_locals != nullptr) {
     TF_VLOG(1) << "GetClassNameAndObjFromFrame: ** LOOK FOR LOCALS **";
     py::object locals = py::reinterpret_borrow<py::object>(frame->f_locals);
-    if (!CheckNeuronInternal(frame))
+    if (!CheckInternalModule(frame))
       found_name = ReverseSearchBreadthFirst(locals, obj, sloc.obj_name);
     TF_VLOG(1) << "GetClassNameAndObjFromFrame: Tried local found = "
                << found_name;
@@ -384,7 +401,7 @@ void GetClassNameAndObjFromFrame(PyFrameObject* frame,
                << frame_id << " **";
     py::object locals =
         py::reinterpret_borrow<py::object>(parent_frame->f_locals);
-    if (!CheckNeuronInternal(parent_frame))
+    if (!CheckInternalModule(parent_frame))
       found_name = ReverseSearchBreadthFirst(locals, obj, sloc.obj_name);
     TF_VLOG(1) << "GetClassNameAndObjFromFrame: Tried (parent) local found = "
                << found_name;
diff --git a/torch_xla/csrc/python_util.h b/torch_xla/csrc/python_util.h
index 61c75c5591c..10345fceb44 100644
--- a/torch_xla/csrc/python_util.h
+++ b/torch_xla/csrc/python_util.h
@@ -3,8 +3,24 @@
 #include "torch/csrc/lazy/core/ir_metadata.h"
 #include "torch_xla/csrc/ir_metadata.h"
 
+#include <mutex>
+#include <set>
+
 namespace torch_xla {
 
 std::shared_ptr<ExtendedFrameInfo> GetExtendedFrameInfo();
 
+class InternalModuleRegistry {
+public:
+
+  static InternalModuleRegistry* Instance();
+  void AddInternalModule(const std::string& module);
+  bool CheckModuleString(const std::string& module);
+
+private:
+
+  std::mutex mutex_;
+  std::set<std::string> internal_modules_;
+};
+
 } // namespace torch_xla
\ No newline at end of file
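
A minimal usage sketch (not part of the patch itself): with the bindings added above, a framework can register its own module names in the internal-module registry so that frames from those modules are skipped when class/variable names are resolved for HLO op metadata, replacing the previously hard-coded "torch_neuronx" check.

import torch_xla

# Register a module name to be treated as XLA-internal; frames whose __name__
# equals a registered entry are not searched for variable names during lowering.
torch_xla._XLAC._add_xla_internal_module("torch_neuronx")

# The check is an exact lookup against the registered names.
assert torch_xla._XLAC._check_xla_internal_module("torch_neuronx")
assert not torch_xla._XLAC._check_xla_internal_module("some_other_module")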