Skip to content

Commit

Permalink
down into cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 24, 2024
1 parent f11daa0 commit 3aa67e7
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,68 @@ def body_fn(upper, lower, one_value, x, input_value, output_value):
self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
return aaa

###
#
def test_while_loop_tpu_simple_linear_class_wrapper(self):
    """Exercise the experimental XLA `while_loop` driving an `nn.Module` body.

    Wraps a `torch.nn.Linear(10, 20)` inside a module whose `forward`
    delegates to `while_loop`, runs the loop on the XLA device, and compares
    the final output against a plain-Python reference loop
    (`_fake_fori_loop`) using a second Linear seeded with the same
    weight/bias.
    """
    print("$$$ test_while_loop_tpu_simple_linear_class_wrapper !!!")
    # Flush any pending lazy ops so this test traces a fresh computation.
    xm.mark_step()
    device = xm.xla_device()
    # Gradients are irrelevant here and would complicate the traced graph.
    torch.set_grad_enabled(False)

    class SimpleWithLinear(torch.nn.Module):
        """Module wrapping a Linear layer driven by `while_loop` in forward."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 20).to(xm.xla_device())

        def forward(self, upper, lower, one_value, x, input_value, output_value):

            # Loop continues while the counter `lower` is below `upper`.
            def cond_fn(upper, lower, one_value, x, input_value, output_value):
                return lower[0] < upper[0]

            # One iteration: bump the counter and apply the Linear layer.
            # All carried values must be returned (cloned) so the traced
            # body has the same signature as its inputs.
            def body_fn(upper, lower, one_value, x, input_value, output_value):
                new_lower = torch.add(one_value, lower)
                output_value_real = self.linear(input_value)
                # The two reads below are intentionally "unused": touching the
                # parameters inside the body is a placeholder required so the
                # xlacomputation captures them as inputs.
                weight = self.linear.weight  # not used; placeholder for xlacomputation capture
                bias = self.linear.bias  # not used; placeholder for xlacomputation capture
                return upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
                    one_value, x), input_value.clone(
                    ), output_value_real

            return while_loop(
                cond_fn, body_fn,
                (upper, lower, one_value, x, input_value, output_value))

    simple_with_linear = SimpleWithLinear()
    upper = torch.tensor([52], dtype=torch.int32, device=device)  # loop runs 52 iterations
    lower = torch.tensor([0], dtype=torch.int32, device=device)
    one_value = torch.tensor([1], dtype=torch.int32, device=device)
    init_val = torch.tensor([1], dtype=torch.int32, device=device)
    l_in_0 = torch.rand(10, device=xm.xla_device())  # random input to the Linear layer
    output_value = torch.zeros([20], dtype=torch.float32, device=device)

    # Unused locals; presumably kept from an earlier revision for debugging.
    weight_0 = simple_with_linear.linear.weight
    bias_0 = simple_with_linear.linear.bias

    # Returned to the caller for inspection; mirrors the shape used by the
    # sibling test above.
    aaa = {
        "simple_with_linear":
            (simple_with_linear, (upper, lower, one_value, init_val, l_in_0,
                                  output_value))
    }

    # NOTE(review): 8 values are unpacked here although `forward` returns the
    # 6-tuple passed to `while_loop` — presumably the experimental `while_loop`
    # appends the captured module parameters (weight, bias) to its outputs;
    # confirm against torch_xla.experimental.fori_loop.
    upper__, lower__, one_value__, torch_add_res__, input_value__, output_value_real__, weight__, bias__ = simple_with_linear(
        upper, lower, one_value, init_val, l_in_0, output_value)

    # Create a Linear with the same weight/bias for the reference computation.
    linear_0 = torch.nn.Linear(10, 20).to(xm.xla_device())
    linear_0.weight.data = weight__
    linear_0.bias.data = bias__
    expected = _fake_fori_loop(lower, upper, linear_0, l_in_0)

    self.assertTrue(torch.all(torch.eq(expected, output_value_real__)))
    return aaa

# additional_inputs: ()
def test_fori_loop_tpu_addition(self):

Expand Down
1 change: 1 addition & 0 deletions torch_xla/experimental/fori_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def new_body_fn(*carried_inputs):
# res.extend(additional_inputs)
# print("res: ", res)
# return list(body_fn(*carried_inputs)).extend(additional_inputs)
self.named_parameters
res = list(body_fn(*carried_inputs))
# print("res: ", res)
# trynewres = res[:-1] + [res[-1]]
Expand Down

0 comments on commit 3aa67e7

Please sign in to comment.