Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fori_loop|While_loop] Enable fori_loop with add/sub test case #6603

Merged
merged 64 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
4604027
Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
ManfeiBai Feb 23, 2024
325341b
Update xla_builder.py
ManfeiBai Feb 23, 2024
021fb49
Update xla_op_registry.py
ManfeiBai Feb 23, 2024
ec7b536
Update helpers.cpp
ManfeiBai Feb 23, 2024
40056a7
Update init_python_bindings.cpp
ManfeiBai Feb 23, 2024
a4d2b58
Update lowering_context.cpp
ManfeiBai Feb 23, 2024
8955e82
Update lowering_context.h
ManfeiBai Feb 23, 2024
cfd1118
Update tensor_methods.cpp
ManfeiBai Feb 23, 2024
6362c07
Update fori_loop.py
ManfeiBai Feb 23, 2024
3436610
clean xla_builder.py
ManfeiBai Feb 23, 2024
70592ad
Update xla_op_registry.py
ManfeiBai Feb 23, 2024
c299607
clean helpers.cpp
ManfeiBai Feb 23, 2024
43988b6
Update init_python_bindings.cpp
ManfeiBai Feb 23, 2024
5d80975
clean lowering_context.cpp
ManfeiBai Feb 23, 2024
09eac90
clean lowering_context.h
ManfeiBai Feb 23, 2024
9a72283
clean tensor_methods.cpp
ManfeiBai Feb 23, 2024
f3b987f
Update fori_loop.py
ManfeiBai Feb 23, 2024
6e6ebcd
Update tensor_methods.cpp
ManfeiBai Feb 23, 2024
5a0ea18
Update init_python_bindings.cpp
ManfeiBai Feb 23, 2024
c1b7c7a
Update lowering_context.cpp
ManfeiBai Feb 23, 2024
7cc755e
format
ManfeiBai Feb 23, 2024
6470ec4
format
ManfeiBai Feb 23, 2024
2cbead2
format
ManfeiBai Feb 23, 2024
b05bf08
try torch func in simple case
ManfeiBai Feb 24, 2024
955e658
Update init_python_bindings.cpp
ManfeiBai Feb 24, 2024
c568fa2
format
ManfeiBai Feb 24, 2024
1ae51f0
update to clean code
ManfeiBai Feb 26, 2024
6e3457f
wrap in new func
ManfeiBai Feb 26, 2024
0eb1aa4
check for test
ManfeiBai Feb 26, 2024
a7621d8
check for test bring back
ManfeiBai Feb 26, 2024
19e977b
Update init_python_bindings.cpp
ManfeiBai Feb 27, 2024
17d2100
Update init_python_bindings.cpp
ManfeiBai Feb 27, 2024
477cdb2
Update fori_loop.py
ManfeiBai Feb 27, 2024
acd0caf
Update fori_loop.py
ManfeiBai Feb 27, 2024
ed7aa1c
Update init_python_bindings.cpp
ManfeiBai Feb 29, 2024
f996c4c
Update init_python_bindings.cpp
ManfeiBai Feb 29, 2024
35878ec
Update fori_loop.py
ManfeiBai Feb 29, 2024
abe3067
remove mark_step
ManfeiBai Feb 29, 2024
56f5cb0
Update lowering_context.cpp
ManfeiBai Feb 29, 2024
b5714f7
Update lowering_context.h
ManfeiBai Feb 29, 2024
e7fa408
Update init_python_bindings.cpp
ManfeiBai Feb 29, 2024
66402ef
Update lowering_context.cpp
ManfeiBai Feb 29, 2024
6b1a570
Update lowering_context.h
ManfeiBai Feb 29, 2024
40d39fa
Update lowering_context.h
ManfeiBai Feb 29, 2024
93638c2
Update fori_loop.py
ManfeiBai Feb 29, 2024
28a3dd7
Update fori_loop.py
ManfeiBai Feb 29, 2024
e014c1d
refactor logic of check while_loop
ManfeiBai Mar 4, 2024
4e6f076
update comment
ManfeiBai Mar 4, 2024
dafa8fb
format
ManfeiBai Mar 4, 2024
1b952cb
format
ManfeiBai Mar 4, 2024
8ec55c3
format
ManfeiBai Mar 4, 2024
aeb0a58
add test case
ManfeiBai Mar 4, 2024
02bf5ba
add test case
ManfeiBai Mar 4, 2024
d51e069
add test case
ManfeiBai Mar 5, 2024
5f3af2b
Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
ManfeiBai Mar 5, 2024
cb26670
test case for add and sub
ManfeiBai Mar 7, 2024
f37149b
format
ManfeiBai Mar 7, 2024
a48a3cf
format
ManfeiBai Mar 7, 2024
f30eaab
wrap test situation check
ManfeiBai Mar 7, 2024
1633d8e
wrap test situation check
ManfeiBai Mar 7, 2024
8a3b9fe
wrap test situation check
ManfeiBai Mar 7, 2024
f6fdc78
format
ManfeiBai Mar 7, 2024
5c2a2a0
Update test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
ManfeiBai Mar 7, 2024
173ff44
nit
ManfeiBai Mar 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,47 @@


def _fake_while_loop(cond_fn, body_fn, operands):
while cond_fn(*operands):
operands = body_fn(*operands)
while cond_fn(operands[0], operands[1]):
operands = body_fn(operands[0], operands[1])
return operands


class WhileLoopTest(unittest.TestCase):
  """End-to-end tests for while_loop lowered to a single XLA While op.

  Each test builds a (init, limit_value) operand pair on the XLA device,
  runs while_loop, and compares against the pure-Python reference
  implementation _fake_while_loop.
  """

  def test_while_loop_tpu_subtraction(self):
    """Counts init down towards limit_value using torch.sub in the body."""

    device = xm.xla_device()

    def cond_fn(init, limit_value):
      # Continue while init has not dropped below the limit.
      return limit_value[0] <= init[0]

    def body_fn(init, limit_value):
      one_value = torch.ones(1, dtype=torch.int32, device=device)
      # Clone so the loop-carried limit is a fresh tensor each iteration.
      two_value = limit_value.clone()
      return (torch.sub(init, one_value), two_value)

    init = torch.tensor([10], dtype=torch.int32, device=device)
    limit_value = torch.tensor([0], dtype=torch.int32, device=device)
    res = while_loop(cond_fn, body_fn, (init, limit_value))
    # Compare against the pure-Python reference implementation.
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)

  def test_while_loop_tpu_addition(self):
    """Counts init up towards limit_value using torch.add in the body."""

    device = xm.xla_device()

    def cond_fn(init, limit_value):
      # Continue while init has not exceeded the limit.
      return limit_value[0] >= init[0]

    def body_fn(init, limit_value):
      one_value = torch.ones(1, dtype=torch.int32, device=device)
      return (torch.add(init, one_value), limit_value.clone())

    # TODO(@manfei): init and limit_value have to be torch.tensor.
    init = torch.tensor([0], dtype=torch.int32, device=device)
    limit_value = torch.tensor([10], dtype=torch.int32, device=device)
    res = while_loop(cond_fn, body_fn, (init, limit_value))
    expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
    self.assertEqual(expected, res)


Expand Down
26 changes: 25 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,22 @@ class PyLoweringContext {
lowering_ctx.AddResult(root);
}
computation = ConsumeValue(lowering_ctx.BuildXla());

// wrap inputs of cond/body_computation
if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
std::vector<size_t> buffer_donor_indices;
xla::ProgramShape program_shape =
ConsumeValue(computation.GetProgramShape());
// TODO(@manfei): please confirm whether we check for more than two or use
// default value true
bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
if (should_wrap_parameter) {
computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
computation, program_shape.parameters(), input_output_alias_pair,
buffer_donor_indices));
}
}
}

// Get a mapping from the HLO input parameters to the backing Tensor values.
Expand Down Expand Up @@ -983,6 +999,12 @@ class PyLoweringContext {
return result;
}

// Tags the underlying LoweringContext with a caller-supplied name
// (e.g. "condctx" / "bodyctx") so later build steps can detect that the
// computation being built is a while-loop cond/body and adjust lowering.
void SetNameString(const std::string& name) {
  lowering_ctx.set_name_string(name);
}

// Returns the name previously set via SetNameString; empty if never set.
std::string GetNameString() { return lowering_ctx.get_name_string(); }

private:
LoweringContext lowering_ctx;
xla::XlaComputation computation;
Expand Down Expand Up @@ -1027,7 +1049,9 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("set_name_string", &PyLoweringContext::SetNameString)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, thanks!

.def("get_name_string", &PyLoweringContext::GetNameString);
}

void InitXlaModuleBindings(py::module m) {
Expand Down
8 changes: 7 additions & 1 deletion torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,13 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {

xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
xla::StatusOr<xla::XlaComputation> xla;
if (!root_tuple_.empty()) {

// check whether build for cond/body computation or not, and skip Tuple step
// if yes
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Untuple the results of body_fn and cond_fn for the xla::while op.

if (!root_tuple_.empty() & (root_tuple_.size() == 1) &
((get_name_string() == "condctx") or (get_name_string() == "bodyctx"))) {
xla = builder()->Build(root_tuple_.at(0));
} else if (!root_tuple_.empty()) {
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
xla = builder()->Build(root);
} else {
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class LoweringContext : public torch::lazy::LoweringContext {

xla::XlaBuilder* builder() { return &builder_; }

void set_name_string(const std::string& name) { name_ = name; }

const std::string& get_name_string() { return name_; }

StackFrameIndexBuilder* stack_frame_index_builder() {
return stack_frame_index_builder_.get();
}
Expand Down Expand Up @@ -121,6 +125,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
parameters_map_;
std::vector<xla::XlaOp> root_tuple_;
OutputMap<xla::XlaOp> emitted_outputs_;
std::string name_;

std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
}; // namespace torch_xla
Expand Down
60 changes: 41 additions & 19 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,44 @@ def while_loop(cond_fn, body_fn, operands):

def _xla_while_loop(cond_fn, body_fn, operands):
  """Lower a Python-level while loop to a single XLA While op.

  Traces ``cond_fn`` and ``body_fn`` through dedicated LoweringContexts to
  obtain their XLA computations, stitches them into an xla::While over
  parameter placeholders shaped like ``operands``, and executes the result
  as a user computation.

  Args:
    cond_fn: callable over the unpacked operands returning a scalar bool
      tensor; loop continues while it is true.
    body_fn: callable over the unpacked operands returning the next
      operands tuple.
    operands: tuple of loop-carried XLA tensors.

  Returns:
    The loop-carried values after the XLA While op finishes.
  """
  # Create one XLA parameter placeholder per operand tensor.
  shapes = xb.tensor_shape(operands)
  builder = xb.create_builder('test_while')
  params = [xb.mkparam(builder, i, shape) for i, shape in enumerate(shapes)]

  # Generate the cond_fn xlacomputation. The context is tagged "condctx"
  # so the C++ lowering skips wrapping the single result in a tuple.
  cond_result = cond_fn(*operands)
  cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
  cond_ctx.set_name_string("condctx")
  cond_ctx.build([cond_result])
  cond_computation = xb.computation_from_module_proto("condcomputation",
                                                      cond_ctx.hlo())

  # Generate the body_fn xlacomputation, tagged "bodyctx" for the same
  # special-casing on the C++ side.
  body_result = body_fn(*operands)
  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
  body_ctx.set_name_string("bodyctx")
  body_ctx.build(list(body_result))
  body_computation = xb.computation_from_module_proto("bodycomputation",
                                                      body_ctx.hlo())

  # Generate the enclosing While xlacomputation over the parameter tuple.
  input_tuple = xb.Op.tuple(tuple(params))
  w = xb.mkop(
      'While', (input_tuple.op,),
      condition_computation=cond_computation,
      body_computation=body_computation)
  computation = w.build('fori_loop_ed_torch_func')

  # Gain the final result by running the generated while xlacomputation.
  result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
                                                 tuple(operands), computation)
  return result
Loading