diff --git a/test/test_operations.py b/test/test_operations.py index 22e03c196b9..cc3a73c4580 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2640,12 +2640,14 @@ def test_api(self): result = a + b - ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName") ctx.build([result]) hlo = ctx.hlo() hlo_text = ctx.hlo_text() - self.assertTrue('opcode: "parameter"' in hlo_text) - self.assertTrue('opcode: "add"' in hlo_text) + self.assertIn('MyCustomName', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "parameter"', hlo_text) + self.assertIn('opcode: "add"', hlo_text) mapping = ctx.parameter_id_tensor_mapping() self.assertEqual(len(mapping), 2) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 959bbfe2dd8..f245efdb08a 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -978,10 +978,14 @@ void BuildProfilerSubmodule(py::module* m) { class PyLoweringContext { public: - PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {} + PyLoweringContext() + : PyLoweringContext("PyLoweringContext", bridge::GetCurrentDevice()) {} - PyLoweringContext(torch::lazy::BackendDevice device) - : lowering_ctx("PyLoweringContext", device) {} + PyLoweringContext(const std::string& name) + : PyLoweringContext(name, bridge::GetCurrentDevice()) {} + + PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device) + : lowering_ctx(name, device) {} // Builds a HLO graph given a set of output tensors. void Build(std::vector tensors) { @@ -1188,7 +1192,7 @@ void BuildLoweringContextSubmodule(py::module* m) { py::class_> lowering_context_class(lowering, "LoweringContext", py::module_local()); - lowering_context_class.def(py::init<>()) + lowering_context_class.def(py::init()) .def("build", &PyLoweringContext::Build) .def("buildforiloop", &PyLoweringContext::BuildForiLoop) .def("hlo", &PyLoweringContext::GetHlo) diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 2e0ced89927..d8d30bb6743 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -477,7 +477,7 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: y_len = len(fake_output_y) fn_outputs = fake_output_carry + fake_output_y - fn_ctx = torch_xla._XLAC.lowering.LoweringContext() + fn_ctx = torch_xla._XLAC.lowering.LoweringContext("FnComputation") fn_ctx.set_name_string("fn_ctx") fn_ctx.build(list(fn_outputs)) fn_hlo = fn_ctx.hlo()