Add python binding to ignore internal stack frames for name lookups.

Address PR feedback (refactor)
mrnikwaws committed Nov 1, 2023
1 parent 50ca8ef commit 7ed8482
Showing 5 changed files with 98 additions and 55 deletions.
22 changes: 7 additions & 15 deletions test/test_hlo_metadata.py
@@ -4,14 +4,6 @@
import argparse
import sys

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--replicated', action='store_true')
parser.add_argument('--long_test', action='store_true')
parser.add_argument('--max_diff_count', type=int, default=25)
parser.add_argument('--verbosity', type=int, default=0)
FLAGS, leftovers = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftovers

# Normal imports section starts here.
import torch
import torch_xla
@@ -25,13 +17,18 @@
class TestHloMetaData(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.pre_test_tensor_type = torch.get_default_dtype()
self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
torch.set_default_tensor_type(torch.FloatTensor)
torch_xla._XLAC._set_ir_debug(True)
super(TestHloMetaData, self).setUp()

def tearDown(self):
super(TestHloMetaData, self).tearDown()
torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)

def test_metadata(self):
torch_xla._XLAC._set_ir_debug(True)
layer1 = torch.nn.Linear(4, 4)
nl1 = torch.nn.ReLU()
layer2 = torch.nn.Linear(4, 2)
@@ -69,7 +66,6 @@ def test_metadata(self):
for op in i:
if 'metadata' in op:
meta = op["metadata"]
print(meta)
if len(meta) > 0:
non_zero_metadata = True
for km, vm in meta.items():
@@ -86,11 +82,7 @@


if __name__ == '__main__':
torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
test = unittest.main(exit=False)
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(met.metrics_report())
sys.exit(0 if test.result.wasSuccessful() else 1)
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -36,6 +36,7 @@
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_dump_util.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/python_util.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
@@ -1899,6 +1900,12 @@ void InitXlaModuleBindings(py::module m) {
m.def("_set_ir_debug",
[](bool ir_debug) { FLAGS_torch_lazy_ir_debug = ir_debug; });
m.def("_get_ir_debug", []() { return FLAGS_torch_lazy_ir_debug; });
m.def("_add_xla_internal_module", [](const std::string& module){
InternalModuleRegistry::Instance()->AddInternalModule(module);
});
m.def("_check_xla_internal_module", [](const std::string& module){
return InternalModuleRegistry::Instance()->CheckModuleString(module);
});
m.def("_set_xla_handle_special_scalars", [](bool handle_special_scalars) {
FLAGS_torch_lazy_handle_special_scalars = handle_special_scalars;
});
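The two bindings added above are the Python-facing entry points to the new module registry. A minimal usage sketch, assuming the standard torch_xla import path; only the binding names _add_xla_internal_module and _check_xla_internal_module come from this diff, and "torch_neuronx" is simply the module name that was previously hard-coded in python_util.cpp:

    import torch_xla

    # Register a module whose stack frames should be treated as internal and
    # therefore skipped when HLO metadata name lookups walk the Python stack.
    torch_xla._XLAC._add_xla_internal_module("torch_neuronx")

    # Query the registry; returns True only for an exact, previously registered name.
    assert torch_xla._XLAC._check_xla_internal_module("torch_neuronx")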
69 changes: 40 additions & 29 deletions torch_xla/csrc/lowering_context.cpp
@@ -37,46 +37,33 @@ class HloMetadataSetter {
}

private:
static bool ShouldPopulateXlaOpMetadata() {
static bool op_metadata =
runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
// Allows us to turn on HLO lowering with a python binding
// e.g. in a profile context
return FLAGS_torch_lazy_ir_debug || op_metadata;
}

static void PopulateXlaOpMetadata(LoweringContext* loctx,
const torch::lazy::Node* node) {
xla::OpMetadata metadata;
// NOTE: we apply some string manipulation as xprof backend utility
// for nesting/grouping traces depends on certain op name/type
// patterns for classification.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
std::string op_type =
absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
metadata.set_op_type(op_type);
const torch::lazy::MetaData& nmeta = node->metadata();

static std::string FetchDebugOpNamePrefix(
LoweringContext* loctx,
const torch::lazy::MetaData& nmeta,
const torch::lazy::Node* node) {
ExtendedFrameInfo* efi = nullptr;
std::string class_prefix;

if (FLAGS_torch_lazy_ir_debug) {
TF_VLOG(1) << "PopulateXlaOpMetadata: Debug lowering enabled";
efi = static_cast<ExtendedFrameInfo*>(node->user_metadata());
TF_VLOG(1) << "FetchDebugOpNamePrefix: Debug lowering enabled";
efi = dynamic_cast<ExtendedFrameInfo*>(node->user_metadata());

if (efi != nullptr && efi->frames.size() != nmeta.frame_info.size()) {
LOG(WARNING) << "PopulateXlaOpMetadata: Extra frame information length "
LOG(WARNING) << "FetchDebugOpNamePrefix: Extra frame information length "
"does not match source "
<< "location length as expected " << efi->frames.size()
<< " != " << nmeta.frame_info.size();
return class_prefix;
}
} else {
TF_VLOG(1) << "PopulateXlaOpMetadata: Debug lowering is *not* enabled";
TF_VLOG(1) << "FetchDebugOpNamePrefix: Debug lowering is *not* enabled";
return class_prefix;
}

// Add class information iff there are frames
std::string last_class_name;
std::string class_prefix;
// Add class information iff there is extended frame information
std::stringstream class_ss;
std::string last_class_name;
std::string variable_name;
std::string last_variable_name;

@@ -110,10 +97,32 @@ class HloMetadataSetter {
last_class_name = it->class_name;
}
class_prefix = class_ss.str();
TF_VLOG(1) << "PopulateXlaOpMetadata: Class prefix = '" << class_prefix
TF_VLOG(1) << "FetchDebugOpNamePrefix: Class prefix = '" << class_prefix
<< "'";
}

return class_prefix;
}

static bool ShouldPopulateXlaOpMetadata() {
static bool op_metadata =
runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
// Allows us to turn on HLO lowering with a python binding
// e.g. in a profile context
return FLAGS_torch_lazy_ir_debug || op_metadata;
}

static void PopulateXlaOpMetadata(LoweringContext* loctx,
const torch::lazy::Node* node) {
xla::OpMetadata metadata;
const torch::lazy::MetaData& nmeta = node->metadata();
// NOTE: we apply some string manipulation as xprof backend utility
// for nesting/grouping traces depends on certain op name/type
// patterns for classification.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
std::string op_type =
absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
metadata.set_op_type(op_type);
if (!nmeta.frame_info.empty()) {
// Print callstack to debug
int i = 0;
@@ -130,10 +139,12 @@
metadata.set_source_line(frame.line);
}

std::string op_name_prefix = class_prefix;
// Set to empty string if there is no debug information attached to the node from lowering
std::string op_name_prefix = FetchDebugOpNamePrefix(loctx, nmeta, node);

if (!nmeta.scope.empty())
op_name_prefix = absl::StrCat(
absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/", class_prefix);
absl::StrReplaceAll(nmeta.scope, {{":", "_"}}), "/", op_name_prefix);

TF_VLOG(1) << "PopulateXlaOpMetadata: Op name prefix = '" << op_name_prefix
<< "'";
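ShouldPopulateXlaOpMetadata above gates all of this on either the lazy-IR debug flag or the XLA_HLO_DEBUG environment variable (read once into a static). A hedged sketch of enabling it from Python, mirroring what the updated test does with _set_ir_debug; the tiny model below is illustrative only:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    # Either switch enables op metadata during HLO lowering. The environment
    # variable is read once, so it must be set before the first lowering.
    os.environ["XLA_HLO_DEBUG"] = "1"
    torch_xla._XLAC._set_ir_debug(True)  # also enables the extended-frame-info path

    device = xm.xla_device()
    layer = torch.nn.Linear(4, 4).to(device)
    out = layer(torch.rand(4, 4, device=device))
    xm.mark_step()  # lowering attaches op_type / op_name / source metadata here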
39 changes: 28 additions & 11 deletions torch_xla/csrc/python_util.cpp
@@ -10,12 +10,33 @@
#include "absl/types/optional.h"
#include "torch/csrc/jit/python/pybind.h"
#include "torch_xla/csrc/ir_metadata.h"

#ifdef VLOG_IS_ON
#undef VLOG_IS_ON
#endif
#include "torch_xla/csrc/runtime/tf_logging.h"

namespace py = pybind11;

namespace torch_xla {

InternalModuleRegistry* InternalModuleRegistry::Instance() {
static InternalModuleRegistry s_Instance;
return &s_Instance;
}

void InternalModuleRegistry::AddInternalModule(
const std::string& module) {
std::lock_guard<std::mutex> guard(mutex_);
internal_modules_.insert(module);
}

bool InternalModuleRegistry::CheckModuleString(
const std::string& module) {
std::lock_guard<std::mutex> guard(mutex_);
return internal_modules_.find(module) != internal_modules_.end();
}

std::string AddrToString(py::handle& obj) {
std::stringstream ss;
ss << std::hex << obj.ptr();
@@ -85,25 +106,21 @@ bool CheckIgnoredKey(const std::string& key) {
key[len - 2] == '_' && key[len - 1] == '_')));
}

bool CheckNeuronInternal(PyFrameObject* frame) {
bool neuronx_call = false;
static std::string neuronx_string("torch_neuronx");
bool CheckInternalModule(PyFrameObject* frame) {
bool internal_call = false;

if (frame != nullptr && frame->f_globals != nullptr) {
auto dict = py::reinterpret_borrow<py::dict>(frame->f_globals);

if (dict.contains("__name__")) {
std::string module_name = py::cast<std::string>(dict["__name__"]);
TF_VLOG(1) << "CheckNeuronInternal: Module name = "
TF_VLOG(1) << "CheckInternalModule: Module name = "
<< py::cast<std::string>(dict["__name__"]);
if (module_name.length() >= neuronx_string.length() &&
module_name.find(neuronx_string) != std::string::npos) {
neuronx_call = true;
}
internal_call = InternalModuleRegistry::Instance()->CheckModuleString(module_name);
}
}

return neuronx_call;
return internal_call;
}

bool ReverseSearchBreadthFirst(py::object& container, py::object& obj_to_find,
@@ -365,7 +382,7 @@ void GetClassNameAndObjFromFrame(PyFrameObject* frame,
if (!found_name && frame->f_locals != nullptr) {
TF_VLOG(1) << "GetClassNameAndObjFromFrame: ** LOOK FOR LOCALS **";
py::object locals = py::reinterpret_borrow<py::object>(frame->f_locals);
if (!CheckNeuronInternal(frame))
if (!CheckInternalModule(frame))
found_name = ReverseSearchBreadthFirst(locals, obj, sloc.obj_name);
TF_VLOG(1) << "GetClassNameAndObjFromFrame: Tried local found = "
<< found_name;
@@ -384,7 +401,7 @@
<< frame_id << " **";
py::object locals =
py::reinterpret_borrow<py::object>(parent_frame->f_locals);
if (!CheckNeuronInternal(parent_frame))
if (!CheckInternalModule(parent_frame))
found_name = ReverseSearchBreadthFirst(locals, obj, sloc.obj_name);
TF_VLOG(1) << "GetClassNameAndObjFromFrame: Tried (parent) local found = "
<< found_name;
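CheckInternalModule generalizes the old CheckNeuronInternal: instead of matching the hard-coded "torch_neuronx" string, it consults the registry, so any library can opt its frames out of the reverse name search. A hedged sketch of how such a library might register itself at import time; the package name is hypothetical, and only _add_xla_internal_module comes from this diff:

    # my_backend/lowering_helpers.py -- hypothetical vendor module
    import torch_xla

    # Register this module's own name so frames originating from it are
    # skipped and ops get attributed to the user's code instead.
    torch_xla._XLAC._add_xla_internal_module(__name__)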
16 changes: 16 additions & 0 deletions torch_xla/csrc/python_util.h
@@ -3,8 +3,24 @@
#include "torch/csrc/lazy/core/ir_metadata.h"
#include "torch_xla/csrc/ir_metadata.h"

#include <mutex>
#include <set>

namespace torch_xla {

std::shared_ptr<torch::lazy::UserMetaData> GetExtendedFrameInfo();

class InternalModuleRegistry {
public:

static InternalModuleRegistry* Instance();
void AddInternalModule(const std::string& module);
bool CheckModuleString(const std::string& module);

private:

std::mutex mutex_;
std::set<std::string> internal_modules_;
};

} // namespace torch_xla
