Skip to content

Commit

Permalink
Give the lowered fn computation a more meaningful name
Browse files Browse the repository at this point in the history
  • Loading branch information
tengyifei committed Nov 22, 2024
1 parent 7e4cf9a commit 98222b2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
8 changes: 5 additions & 3 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2640,12 +2640,14 @@ def test_api(self):

result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build([result])
hlo = ctx.hlo()
hlo_text = ctx.hlo_text()
self.assertTrue('opcode: "parameter"' in hlo_text)
self.assertTrue('opcode: "add"' in hlo_text)
self.assertIn('MyCustomName', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "add"', hlo_text)
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)

Expand Down
13 changes: 9 additions & 4 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,14 @@ void BuildProfilerSubmodule(py::module* m) {

class PyLoweringContext {
public:
PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {}
PyLoweringContext()
: PyLoweringContext("PyLoweringContext", bridge::GetCurrentDevice()) {}

PyLoweringContext(torch::lazy::BackendDevice device)
: lowering_ctx("PyLoweringContext", device) {}
PyLoweringContext(const std::string& name)
: PyLoweringContext(name, bridge::GetCurrentDevice()) {}

PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device)
: lowering_ctx(name, device) {}

// Builds a HLO graph given a set of output tensors.
void Build(std::vector<at::Tensor> tensors) {
Expand Down Expand Up @@ -1188,7 +1192,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
py::class_<PyLoweringContext, std::unique_ptr<PyLoweringContext>>
lowering_context_class(lowering, "LoweringContext", py::module_local());

lowering_context_class.def(py::init<>())
lowering_context_class.def(py::init())
.def(py::init<std::string>())
.def("build", &PyLoweringContext::Build)
.def("buildforiloop", &PyLoweringContext::BuildForiLoop)
.def("hlo", &PyLoweringContext::GetHlo)
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/experimental/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
y_len = len(fake_output_y)
fn_outputs = fake_output_carry + fake_output_y

fn_ctx = torch_xla._XLAC.lowering.LoweringContext()
fn_ctx = torch_xla._XLAC.lowering.LoweringContext("FnComputation")
fn_ctx.set_name_string("fn_ctx")
fn_ctx.build(list(fn_outputs))
fn_hlo = fn_ctx.hlo()
Expand Down

0 comments on commit 98222b2

Please sign in to comment.