
[test] Again: while_loop lowering with simple calculation dispatch in torch #6563

Closed

Changes from all commits · 243 commits
0a2cc00
Create fori_loop.py
ManfeiBai Feb 13, 2024
c4b4384
Update fori_loop.py
ManfeiBai Feb 13, 2024
8ac0bb5
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
0a0f8c6
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
c62ad7d
Update init_python_bindings.cpp
ManfeiBai Feb 13, 2024
f0dd53e
Update fori_loop.py
ManfeiBai Feb 13, 2024
dcf65fc
Update fori_loop.py
ManfeiBai Feb 13, 2024
a1ed583
Create test_fori_loop.py
ManfeiBai Feb 13, 2024
cea9062
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
7e1ffcb
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
6694c09
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
caeeb50
Update test_fori_loop.py
ManfeiBai Feb 13, 2024
6a4ad4f
use code from xb
ManfeiBai Feb 13, 2024
76a1cfa
only xb test
ManfeiBai Feb 13, 2024
694e797
check original version has used python dispatch or not
ManfeiBai Feb 13, 2024
fc4c8bb
check original version has used python dispatch or not again
ManfeiBai Feb 13, 2024
cf6aae9
add test script for xla
ManfeiBai Feb 13, 2024
04a8eff
add test script for xla again
ManfeiBai Feb 13, 2024
9c4ba74
test with while_loop_dense
ManfeiBai Feb 13, 2024
e377b77
change dispatchkey
ManfeiBai Feb 13, 2024
27d66f8
re dispatch
ManfeiBai Feb 13, 2024
82612f1
re dispatch again
ManfeiBai Feb 13, 2024
0de6bae
check again
ManfeiBai Feb 13, 2024
0977c54
check type
ManfeiBai Feb 13, 2024
e76484f
checkpoint to show body/cond hlo
ManfeiBai Feb 14, 2024
f5dea5d
correct xb
ManfeiBai Feb 14, 2024
a189666
add example code script
ManfeiBai Feb 14, 2024
12e6886
add example code script again
ManfeiBai Feb 14, 2024
202f1ef
only test tpu
ManfeiBai Feb 14, 2024
d6fd934
only test tpu
ManfeiBai Feb 14, 2024
8b6df8e
add result value check
ManfeiBai Feb 14, 2024
2db33ab
add result value check
ManfeiBai Feb 14, 2024
6d3f3a4
clean code
ManfeiBai Feb 14, 2024
700ed2a
try torchxla code
ManfeiBai Feb 14, 2024
0f05d76
try torchxla code again
ManfeiBai Feb 14, 2024
d7e54c3
try torchxla code again again
ManfeiBai Feb 14, 2024
1b11308
try torchxla code again again again
ManfeiBai Feb 14, 2024
fea71a6
add test on CPU/GPU tests
ManfeiBai Feb 14, 2024
9a42faf
add test on TPU test trigger's
ManfeiBai Feb 14, 2024
0d21e47
add test in TPU CI workflow
ManfeiBai Feb 14, 2024
45c4433
Merge branch 'master' into while_loop-lowering-with-simplecalculation…
ManfeiBai Feb 14, 2024
5050eb4
placeholder for xlacomputation
ManfeiBai Feb 17, 2024
92b69f4
add test
ManfeiBai Feb 17, 2024
c4ed61a
modif test code
ManfeiBai Feb 17, 2024
1c8395e
modify fori_loop
ManfeiBai Feb 17, 2024
c9e7b5f
format
ManfeiBai Feb 19, 2024
02f738f
format
ManfeiBai Feb 19, 2024
d689acc
format
ManfeiBai Feb 19, 2024
122d6e5
format
ManfeiBai Feb 19, 2024
e95ddc3
test log
ManfeiBai Feb 21, 2024
9cacae5
test it
ManfeiBai Feb 21, 2024
3a6de14
test it again
ManfeiBai Feb 21, 2024
e1ec664
test it again again
ManfeiBai Feb 21, 2024
e2bfae4
test it again again again
ManfeiBai Feb 21, 2024
52362aa
test it again again again agian
ManfeiBai Feb 21, 2024
27643f6
test it again again again agian agian
ManfeiBai Feb 21, 2024
0c68bd3
test it again again again agian agian again
ManfeiBai Feb 21, 2024
d6b95bb
test it again again again agian agian again again
ManfeiBai Feb 21, 2024
16ed0b3
test it again again again agian agian again again again again
ManfeiBai Feb 21, 2024
baeee87
test it again again again agian agian again again again again aga
ManfeiBai Feb 21, 2024
da3e56e
test it again again again agian agian again again again again aga
ManfeiBai Feb 21, 2024
9e45118
test it again again again agian agian again again again again aga
ManfeiBai Feb 21, 2024
9be8c97
unwarp
ManfeiBai Feb 21, 2024
c2d2104
warp
ManfeiBai Feb 21, 2024
9fa20a4
warp
ManfeiBai Feb 21, 2024
6d05089
unwarp
ManfeiBai Feb 21, 2024
cfaec4a
unwarp
ManfeiBai Feb 21, 2024
7575801
name
ManfeiBai Feb 21, 2024
fc06368
name
ManfeiBai Feb 21, 2024
1b92b09
name
ManfeiBai Feb 21, 2024
e3348e4
unwrap
ManfeiBai Feb 21, 2024
383f3fb
unwrap
ManfeiBai Feb 21, 2024
53d3039
unwrap
ManfeiBai Feb 21, 2024
724f0bf
unwrap
ManfeiBai Feb 21, 2024
55b645f
unwrap
ManfeiBai Feb 21, 2024
5cb01c0
unwrap
ManfeiBai Feb 21, 2024
601e406
unwrap
ManfeiBai Feb 21, 2024
a27852e
unwrap
ManfeiBai Feb 21, 2024
a306b5b
unwrap
ManfeiBai Feb 21, 2024
c6f6251
unwrap
ManfeiBai Feb 21, 2024
3ace502
unwrap
ManfeiBai Feb 21, 2024
8e63c8f
unwrap
ManfeiBai Feb 21, 2024
0fba660
unwrap
ManfeiBai Feb 21, 2024
30acbf7
unwrap
ManfeiBai Feb 21, 2024
2a335ab
unwrap
ManfeiBai Feb 21, 2024
116a35e
unwrap
ManfeiBai Feb 21, 2024
e32170c
unwrap
ManfeiBai Feb 21, 2024
c5fe5e7
unwrap
ManfeiBai Feb 21, 2024
2d6c82b
unwrap
ManfeiBai Feb 21, 2024
8b858b6
unwrap
ManfeiBai Feb 21, 2024
13ed2ac
unwrap
ManfeiBai Feb 21, 2024
1a205cd
unwrap
ManfeiBai Feb 21, 2024
0eaec23
unwrap
ManfeiBai Feb 21, 2024
438fc48
unwrap
ManfeiBai Feb 21, 2024
fe880dd
unwrap
ManfeiBai Feb 21, 2024
b5660d4
unwrap
ManfeiBai Feb 21, 2024
dd5f2f5
unwrap
ManfeiBai Feb 21, 2024
3696a5f
unwrap
ManfeiBai Feb 21, 2024
86aae55
unwrap
ManfeiBai Feb 21, 2024
39b63f7
unwrap
ManfeiBai Feb 21, 2024
b39e9db
unwrap
ManfeiBai Feb 21, 2024
ed8a594
unwrap
ManfeiBai Feb 21, 2024
b856146
unwrap
ManfeiBai Feb 21, 2024
2084fbb
unwrap
ManfeiBai Feb 21, 2024
b3bd36a
unwrap
ManfeiBai Feb 21, 2024
dbc9980
return
ManfeiBai Feb 21, 2024
ed54291
return
ManfeiBai Feb 21, 2024
f5bed17
return
ManfeiBai Feb 21, 2024
381a530
return
ManfeiBai Feb 21, 2024
2f85648
return
ManfeiBai Feb 22, 2024
3c068a3
return
ManfeiBai Feb 22, 2024
2e45864
return
ManfeiBai Feb 22, 2024
608f8ff
return
ManfeiBai Feb 22, 2024
ecf04fd
return
ManfeiBai Feb 22, 2024
b28aef4
return
ManfeiBai Feb 22, 2024
5b6131b
return
ManfeiBai Feb 22, 2024
85dbd62
return
ManfeiBai Feb 22, 2024
796560d
return
ManfeiBai Feb 22, 2024
3bac1bf
return
ManfeiBai Feb 22, 2024
861168c
return
ManfeiBai Feb 22, 2024
395ad79
return
ManfeiBai Feb 22, 2024
fc8dadf
add
ManfeiBai Feb 22, 2024
eb89764
add
ManfeiBai Feb 22, 2024
e7c8d31
add
ManfeiBai Feb 22, 2024
bf98aa3
add
ManfeiBai Feb 22, 2024
0fe918f
add
ManfeiBai Feb 22, 2024
97194ba
add
ManfeiBai Feb 22, 2024
6951878
add
ManfeiBai Feb 22, 2024
a46d794
add
ManfeiBai Feb 22, 2024
b65620d
add
ManfeiBai Feb 22, 2024
fb7aadf
add
ManfeiBai Feb 22, 2024
f96a1dc
add
ManfeiBai Feb 22, 2024
cf0d6d0
add
ManfeiBai Feb 22, 2024
df0b153
add
ManfeiBai Feb 22, 2024
560ddb8
add
ManfeiBai Feb 22, 2024
f91f53b
add
ManfeiBai Feb 22, 2024
3cf497f
add
ManfeiBai Feb 22, 2024
64f736f
add
ManfeiBai Feb 22, 2024
a5de388
add
ManfeiBai Feb 22, 2024
db70e33
add
ManfeiBai Feb 22, 2024
591635c
add
ManfeiBai Feb 22, 2024
47dec86
add
ManfeiBai Feb 22, 2024
adbba50
add
ManfeiBai Feb 22, 2024
c1b20f6
add
ManfeiBai Feb 22, 2024
83ee3a1
add
ManfeiBai Feb 22, 2024
d82fb7e
add
ManfeiBai Feb 22, 2024
9a47a12
add
ManfeiBai Feb 22, 2024
a52ab4e
add
ManfeiBai Feb 22, 2024
c169eae
add
ManfeiBai Feb 22, 2024
a960bf9
add
ManfeiBai Feb 22, 2024
12803aa
add
ManfeiBai Feb 22, 2024
7ce860e
add
ManfeiBai Feb 22, 2024
2e40011
add
ManfeiBai Feb 22, 2024
4494e20
add
ManfeiBai Feb 22, 2024
1f54281
add
ManfeiBai Feb 22, 2024
136f42b
add
ManfeiBai Feb 22, 2024
53b2eb3
add
ManfeiBai Feb 22, 2024
d3012bb
add
ManfeiBai Feb 22, 2024
284b788
add
ManfeiBai Feb 22, 2024
0d2d1d8
add
ManfeiBai Feb 22, 2024
ca791bc
add
ManfeiBai Feb 22, 2024
effe4b9
add
ManfeiBai Feb 22, 2024
1676eb1
add
ManfeiBai Feb 22, 2024
0bd7c14
add
ManfeiBai Feb 22, 2024
f5d458f
add
ManfeiBai Feb 22, 2024
1c6b57a
add
ManfeiBai Feb 22, 2024
16ef994
add
ManfeiBai Feb 22, 2024
711de52
add
ManfeiBai Feb 22, 2024
476ceb8
add
ManfeiBai Feb 22, 2024
1ec07e9
add
ManfeiBai Feb 22, 2024
6c0c5b8
add
ManfeiBai Feb 22, 2024
c1189e6
add
ManfeiBai Feb 22, 2024
7abcc2b
add
ManfeiBai Feb 22, 2024
296a774
add
ManfeiBai Feb 22, 2024
a6a7263
add
ManfeiBai Feb 22, 2024
c924e25
add
ManfeiBai Feb 22, 2024
3ef9dea
add
ManfeiBai Feb 22, 2024
0df806d
add
ManfeiBai Feb 22, 2024
b4ee102
add
ManfeiBai Feb 22, 2024
48dee94
add
ManfeiBai Feb 22, 2024
bdab0b0
add
ManfeiBai Feb 22, 2024
1a2cc20
add
ManfeiBai Feb 22, 2024
bbecee3
add
ManfeiBai Feb 22, 2024
59312ec
add
ManfeiBai Feb 22, 2024
e70e41b
add
ManfeiBai Feb 22, 2024
9b76908
add
ManfeiBai Feb 22, 2024
e70d42e
add
ManfeiBai Feb 22, 2024
69b1375
add
ManfeiBai Feb 22, 2024
9a88b9c
add
ManfeiBai Feb 22, 2024
73ae176
add
ManfeiBai Feb 22, 2024
c0892bc
add
ManfeiBai Feb 22, 2024
7cf620a
add
ManfeiBai Feb 22, 2024
b622881
add
ManfeiBai Feb 22, 2024
0fdbd74
add
ManfeiBai Feb 22, 2024
1469801
add
ManfeiBai Feb 22, 2024
f35b442
add
ManfeiBai Feb 22, 2024
903311d
add
ManfeiBai Feb 22, 2024
981a3d9
add
ManfeiBai Feb 22, 2024
3a2792f
add
ManfeiBai Feb 23, 2024
962bd2e
add
ManfeiBai Feb 23, 2024
c2c9704
add
ManfeiBai Feb 23, 2024
6b1c90b
add
ManfeiBai Feb 23, 2024
76049fe
add
ManfeiBai Feb 23, 2024
294547c
add
ManfeiBai Feb 23, 2024
83f2e31
add
ManfeiBai Feb 23, 2024
41d38d7
add
ManfeiBai Feb 23, 2024
40badf6
add
ManfeiBai Feb 23, 2024
d1ce7fe
add
ManfeiBai Feb 23, 2024
7100144
add
ManfeiBai Feb 23, 2024
285c33e
add
ManfeiBai Feb 23, 2024
3b3ae8c
add
ManfeiBai Feb 23, 2024
b2b4e1c
add
ManfeiBai Feb 23, 2024
9b5c218
add
ManfeiBai Feb 23, 2024
7c90e76
add
ManfeiBai Feb 23, 2024
48a9b16
add
ManfeiBai Feb 23, 2024
484c631
add
ManfeiBai Feb 23, 2024
936cd4f
add
ManfeiBai Feb 23, 2024
cd105ce
add
ManfeiBai Feb 23, 2024
f06f405
add
ManfeiBai Feb 23, 2024
124b4f6
add
ManfeiBai Feb 23, 2024
302e666
add
ManfeiBai Feb 23, 2024
564b4c0
add
ManfeiBai Feb 23, 2024
691cc55
add
ManfeiBai Feb 23, 2024
5e02b6c
add
ManfeiBai Feb 23, 2024
e7a9892
add
ManfeiBai Feb 23, 2024
571b471
add
ManfeiBai Feb 23, 2024
2b6c60b
add
ManfeiBai Feb 23, 2024
4f01799
add
ManfeiBai Feb 23, 2024
d2b3d08
add
ManfeiBai Feb 23, 2024
4c1bdde
add
ManfeiBai Feb 23, 2024
7d6bde7
add
ManfeiBai Feb 23, 2024
56c019a
add
ManfeiBai Feb 23, 2024
a079874
add
ManfeiBai Feb 23, 2024
0a2d957
add
ManfeiBai Feb 23, 2024
3b6df5f
add
ManfeiBai Feb 23, 2024
3eb1c9b
add
ManfeiBai Feb 23, 2024
f8052e1
add
ManfeiBai Feb 23, 2024
0d5bf3d
add
ManfeiBai Feb 23, 2024
efe4464
add
ManfeiBai Feb 23, 2024
5759b57
add
ManfeiBai Feb 23, 2024
dec38c9
add
ManfeiBai Feb 23, 2024
9da0a64
torch sub
ManfeiBai Feb 23, 2024
57f76f2
update code
ManfeiBai Feb 27, 2024
1 change: 1 addition & 0 deletions .github/workflows/tpu_ci.yml
@@ -46,3 +46,4 @@ jobs:
python3 -u test/test_autocast.py
python3 -u test/dynamo/test_dynamo.py
python3 -u test/spmd/test_spmd_debugging.py
python3 -u test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -188,6 +188,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

55 changes: 55 additions & 0 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -0,0 +1,55 @@
import os
import sys
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def _fake_while_loop(cond_fn, body_fn, operands):
# print("fake func operands: ", operands)
# print("fake func *operands: ", *operands)
# while cond_fn(*operands):
# operands = body_fn(*operands)
# return operands
while cond_fn(operands):
operands = body_fn(operands)
return operands


class WhileLoopTest(unittest.TestCase):

def test_while_loop_tpu(self):

device = xm.xla_device()
# ten = torch.ones(1, dtype=torch.int32, device=device)
# ten[0] = 10

def cond_fn(x): # x = (xi,)
ten = torch.ones(1, dtype=torch.int32, device=device)
# ten = torch.tensor([5], dtype=torch.int32, device=device)
return x[0] >= ten[0] # ==x[0] # torch.equal(x[0], ten) # x[0] <= ten # 30

def body_fn(x): # x = (xi,)
# onei = torch.tensor(10, dtype=torch.int32, device=device)
return (torch.sub(x[0], 1),) # onei,)

# device = xm.xla_device()
# xi = torch.ones(1, dtype=torch.int32, device=device)
xi = torch.tensor([5], dtype=torch.int32, device=device)
# xi[0] = 5
# xi.to(device)
# yi = torch.ones(1, dtype=torch.int32, device=device)
# yi = torch.tensor([1], dtype=torch.int32, device=device)
res = while_loop(cond_fn, body_fn, (xi,))
expected = _fake_while_loop(cond_fn, body_fn, xi)
self.assertEqual(expected, res)


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
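For reference, the semantics this test exercises can be restated as a plain-Python reference loop. This is a sketch based on the commented-out variant of `_fake_while_loop` above, using the tuple-of-operands calling convention of `torch._higher_order_ops.while_loop`; it runs on plain ints, not XLA tensors:

```python
def while_loop_reference(cond_fn, body_fn, operands):
    # Pure-Python reference for while_loop semantics: keep applying
    # body_fn while cond_fn holds, threading the operand tuple through.
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands

# The same countdown as the test above: start at 5, subtract 1 while x >= 1.
result = while_loop_reference(lambda x: x >= 1,
                              lambda x: (x - 1,),
                              (5,))
# → (0,)
```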
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -61,6 +61,7 @@ spec:
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
python3 /src/pytorch/xla/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
volumeMounts:
- mountPath: /dev/shm
name: dshm
3 changes: 3 additions & 0 deletions torch_xla/core/xla_builder.py
@@ -762,6 +762,9 @@ def mkop(name, ops, **kwargs):
builder = kwargs.get('builder', None)
if builder is None:
assert ops
# if not isinstance(ops, (list, tuple)):
# builder = torch_xla._XLAC._xla_op_builder(ops)
# else:
builder = torch_xla._XLAC._xla_op_builder(ops[0])
return Op(torch_xla._XLAC._xla_op_create(builder, name, ops, kwargs))

2 changes: 2 additions & 0 deletions torch_xla/core/xla_op_registry.py
@@ -41,6 +41,8 @@ def __call__(self, *args, **kwargs):
self._computations[key] = computation
if xu.getenv_as('XLA_OP_PRINT_COMPUTATIONS', bool, False):
print(xb.get_computation_hlo(computation), file=sys.stderr)
print("777777777 args: ", args)
print("777777777 type args: ", type(args))
result = torch_xla._XLAC._xla_user_computation(self._opname, args,
computation)
return result[0] if len(result) == 1 else result
2 changes: 2 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -937,6 +937,8 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
/*param_index=*/xla::ShapeIndex({input_index}));
}

// xla::XlaOp a = xla::GetTupleElement(orig_result, 0);

return builder.Build(orig_result);
}

62 changes: 61 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -17,6 +17,7 @@
#include <thread>
#include <unordered_map>
#include <vector>
#include <iostream>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
@@ -671,7 +672,17 @@ py::object XlaNms(const at::Tensor& boxes, const at::Tensor& scores,
std::vector<at::Tensor> XlaUserComputation(
const std::string& opname, const std::vector<at::Tensor>& inputs,
runtime::ComputationClient::ComputationPtr computation) {
std::cout << " !!!$$$: " << std::endl;
for (int i = 0; i < inputs.size(); i++) {
std::cout << inputs[i] << "; " << std::endl;
std::cout << inputs[i].type() << "; type !!!" << std::endl;
// inputs[i] = (inputs[i]);
}
std::vector<XLATensorPtr> xinputs = GetXlaTensors(inputs, /*want_all=*/true);
// std::cout << " !!!$$$###: " << std::endl;
// for (int i = 0; i < xinputs.size(); i++) {
// std::cout << DumpUtil::ToText(xinputs[i]->CurrentIrValue().node.get()) << "; ";
// }
std::vector<XLATensorPtr> xresults =
tensor_methods::user_computation(opname, xinputs, std::move(computation));
std::vector<at::Tensor> results;
@@ -685,7 +696,14 @@ std::vector<at::Tensor> XlaUserComputation(

runtime::ComputationClient::ComputationPtr CreateComputation(
const std::string& name, xla::XlaOp root) {
std::cout << "w's build func name: " << name << std::endl;
// std::cout << "w's build builder name: " << root.builder().name_ << std::endl;
// https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/client/xla_builder.cc#L710
xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root));
// std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
// xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
// computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
// computation, program_shape.parameters(), input_output_alias_pair));
return std::make_shared<runtime::ComputationClient::Computation>(
name, std::move(computation));
}
@@ -876,13 +894,19 @@ void BuildProfilerSubmodule(py::module* m) {

class PyLoweringContext {
public:
PyLoweringContext(const std::string& name) : PyLoweringContext(name, bridge::GetCurrentDevice()) {}

PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device)
: lowering_ctx(name, device) {}

PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {}

PyLoweringContext(torch::lazy::BackendDevice device)
: lowering_ctx("PyLoweringContext", device) {}

// Builds a HLO graph given a set of output tensors.
void Build(std::vector<at::Tensor> tensors) {
// std::cout<< "let's see how many timed this was called? !!!" << GetNameString() << std::endl;
// Get the backing XLA tensors from the output torch tensor handles
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
@@ -894,13 +918,31 @@ class PyLoweringContext {
ir_values.push_back(value);
}

// // check computation name
// XLA_ERROR() << computation.proto().name();

// Lower the graph using the output IR values
for (auto& ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
// if (computation.proto().name()=='condctx') {
// xla::XlaOp a = xla::GetTupleElement(root, 0); // they are not tupled here
// }
lowering_ctx.AddResult(root);
}
computation = ConsumeValue(lowering_ctx.BuildXla());

std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
bool should_wrap_parameter = (program_shape.parameters_size() >= 2); // true;
if (should_wrap_parameter) {
computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
computation, program_shape.parameters(), input_output_alias_pair));
}

// // unwrap (pred[])
// xla::XlaBuilder builder(computation.proto().name());
// xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);
}
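The wrapping decision added to `PyLoweringContext::Build` above reduces to a small predicate. A Python sketch of the C++ control flow (not an actual API — the function name is illustrative): the computation is rewrapped via `XlaHelpers::WrapXlaComputation` only when it takes two or more parameters.

```python
def should_wrap_parameters(num_parameters: int) -> bool:
    # Mirrors the C++ condition in this PR:
    #   bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
    # Zero- and one-parameter computations are left as-is; anything larger
    # gets its parameters tupled into a single wrapped parameter.
    return num_parameters >= 2
```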

// Get a mapping from the HLO input parameters to the backing Tensor values.
@@ -983,6 +1025,22 @@ class PyLoweringContext {
return result;
}

void SetNameString(const std::string& name) {
lowering_ctx.setnamestring(name);
}

std::string GetNameString() {
return lowering_ctx.getnamestring();
}

// LoweringContext GetLoweringCtx() {
// return lowering_ctx;
// }

// LoweringContext SetLoweringCtxName(const std::string name) {
// lowering_ctx.builder().name_ = name;
// }

private:
LoweringContext lowering_ctx;
xla::XlaComputation computation;
@@ -1027,7 +1085,9 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("setnamestring", &PyLoweringContext::SetNameString)
.def("getnamestring", &PyLoweringContext::GetNameString);
}

void InitXlaModuleBindings(py::module m) {
27 changes: 26 additions & 1 deletion torch_xla/csrc/lowering_context.cpp
@@ -154,9 +154,34 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {

xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
xla::StatusOr<xla::XlaComputation> xla;
if (!root_tuple_.empty()) {
// if (builder_.name() == 'bodyctx') {
// XLA_ERROR() << builder_.name();
// }
std::cout << "???" << getnamestring();
if (!root_tuple_.empty() && (root_tuple_.size() > 1)) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
// xla::XlaOp a = xla::GetTupleElement(root, 0);
xla = builder()->Build(root);
} else if (!root_tuple_.empty() && (root_tuple_.size() == 1)) {
// xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
// xla::XlaOp a = xla::GetTupleElement(root, 0);
// const xla::Shape& root_shape = ShapeHelper::ShapeOfXlaOp(root_tuple_.at(0));
// xla::XlaBuilder cb("predone");
// xla::Shape xla_scalar_shape = xla::ShapeUtil::MakeShape(element_type, {});
// xla::XlaOp p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
// Below are to untuple the parameters
// xla = builder()->Build(root_tuple_.at(0)); // root);
const std::string condctx = "condctx";
const std::string bodyctx = "bodyctx";
// std::cout << "???" << builder()->name();
const std::string currentname = getnamestring();
if ((currentname == condctx) || (currentname == bodyctx)) {
xla = builder()->Build(root_tuple_.at(0)); // root);
}
else {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
}
} else {
xla = builder()->Build();
}
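The branching added to `LoweringContext::BuildXla` above amounts to a three-way dispatch on the number of root ops and the context name. A Python sketch of that C++ logic (the `"tuple"`/`"single"`/`"empty"` labels are illustrative, not real API values):

```python
def build_root_kind(num_roots: int, name: str) -> str:
    # Mirrors the PR's BuildXla: multiple roots are tupled; a single root
    # is built un-tupled only for the while_loop cond/body contexts
    # ("condctx"/"bodyctx"), and tupled otherwise; with no roots the
    # builder is finalized as-is.
    if num_roots > 1:
        return "tuple"
    if num_roots == 1:
        return "single" if name in ("condctx", "bodyctx") else "tuple"
    return "empty"
```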
6 changes: 6 additions & 0 deletions torch_xla/csrc/lowering_context.h
@@ -5,6 +5,7 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir_util.h>

#include <iostream>
#include <memory>
#include <string>
#include <string_view>
@@ -34,6 +35,10 @@ class LoweringContext : public torch::lazy::LoweringContext {

xla::XlaBuilder* builder() { return &builder_; }

void setnamestring(const std::string& name) { name_ = name; std::cout << "LoweringContext~~~??>>: " << name_ << std::endl;}

const std::string& getnamestring() { return name_; }

StackFrameIndexBuilder* stack_frame_index_builder() {
return stack_frame_index_builder_.get();
}
@@ -121,6 +126,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
parameters_map_;
std::vector<xla::XlaOp> root_tuple_;
OutputMap<xla::XlaOp> emitted_outputs_;
std::string name_;

std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
}; // namespace torch_xla
7 changes: 6 additions & 1 deletion torch_xla/csrc/tensor_methods.cpp
@@ -629,9 +629,14 @@ std::vector<XLATensorPtr> user_computation(
runtime::ComputationClient::ComputationPtr computation) {
XLA_CHECK(!inputs.empty());
std::vector<torch::lazy::Value> input_values;
std::vector<const torch::lazy::Node*> root_nodes;
for (auto& input : inputs) {
input_values.push_back(input->GetIrValue());
torch::lazy::Value ir_value = input->GetIrValue();
input_values.push_back(ir_value);
root_nodes.push_back(ir_value.node.get());
}
std::string graph_str = DumpUtil::ToText(root_nodes);
std::cout << "inputs' torch::lazy::Node are ##@#@#@#@#: " << graph_str << std::endl;
torch::lazy::NodePtr node = torch::lazy::MakeNode<UserComputation>(
torch::lazy::OpKind::Get(opname), input_values, std::move(computation));
// Cast can be one of the user computation and we don't want to inherit the