Commit 3ae7c97: clean
ManfeiBai committed May 9, 2024
1 parent 43527e6 commit 3ae7c97
Showing 3 changed files with 12 additions and 209 deletions.
@@ -112,47 +112,12 @@ def test_while_loop_get_xlacomputation(self):
    print("before run _get_xla_computation: !!!!!!!!!")
    res_xla_computation = torch_xla._XLAC._get_xla_computation([t3], [], True)
    print("after run _get_xla_computation: !!!!!!!!!")
-    if res_xla_computation == null:
-      print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")
-    else:
+    if res_xla_computation:
      hlo_print = xb.get_computation_hlo(res_xla_computation)
      print("print computation from _get_xla_computation: !!!!!!!!!")
      print(hlo_print)
-
-    # ### test to see what _get_stablehlo would get
-    # # tensors, torch_xla._XLAC._xla_get_default_device(), [],
-    # res_hlo_xla_computation = torch_xla._XLAC._get_stablehlo([t3], torch_xla._XLAC._xla_get_default_device(), [], False)
-    # print("res hlo: ", res_hlo_xla_computation)
-
-    # ### test to get xlacomputation
-    # body_result0 = torch.add(t1, t2)
-    # body_ctx0 = torch_xla._XLAC.lowering.LoweringContext()
-    # body_ctx0.build(list(body_result0))
-    # body_hlo0 = body_ctx0.hlo()
-    # print("body_hlo0 finish !!!")
-    # print("type body_hlo0: ", type(body_hlo0))
-    # # print("body_hlo0: ", body_hlo0)
-
-    # ### test to use _get_stablehlo to get xlacomputation
-    # body_result = torch.add(t1, t2)
-    # # body_ctx = torch_xla._XLAC.lowering.LoweringContext()
-    # body_stable_hlo = torch_xla._XLAC._get_stablehlo([t3], torch_xla._XLAC._xla_get_default_device(), [], False)
-    # print("body_stable_hlo finish !!!")
-    # print("type body_stable_hlo: ", type(body_stable_hlo))
-    # print("body_stable_hlo: ", body_stable_hlo)
-
-    # if body_hlo0 == body_stable_hlo:
-    #   print("hlo and stablehlo are the same item")
-    # else:
-    #   print("hlo and stablehlo are not the same item")
-    # # body_ctx.set_name_string("bodyctx")
-    # # body_ctx.buildforiloop(list(body_result), [])
-    # # body_hlo = body_ctx.hlo()
-    # body_computation = xb.computation_from_module_proto("bodycomputation",
-    #                                                     body_stable_hlo)
-    # hlo_print = xb.get_computation_hlo(body_computation)
-    # print("print computation from _get_stablehlo: !!!!!!!!!")
-    # print(hlo_print)
+    else:
+      print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

  def test_while_loop_get_xlacomputation_directly(self):

@@ -177,7 +142,6 @@ def test_while_loop_get_xlacomputation_tpu_simple_linear_without_while_loop(self):

    xm.mark_step()
    device = xm.xla_device()
-    #device = ''
    torch.set_grad_enabled(False)

    class SimpleWithLinear(torch.nn.Module):
@@ -189,24 +153,11 @@ def __init__(self):
      def forward(self, x):
        x = self.linear(x)
        return x
-        # def cond_fn(it, x):
-        #   return it - self.dec > 0
-
-        # def body_fn(it, x):
-        #   return it - 1, self.linear(x)
-
-        # return while_loop(cond_fn, body_fn, (iter, x))

    simple_with_linear = SimpleWithLinear()
    simple_with_linear.to(device)
-    #breakpoint()
    input = torch.randn(2, 2).to(device)
-    # iter = torch.tensor(3, device=device)
-    # res = simple_with_linear(iter, input)
    t3 = simple_with_linear(input)
-    # t1 = torch.randn(20, 5).to(device)
-    # t2 = torch.randn(20, 5).to(device)
-    # t3 = torch.add(t1, t2)
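The module's __init__ body is collapsed in this diff view. Based on the 2x2 input created above, a plausible reconstruction of the full nested module follows; the Linear dimensions are an assumption, not shown in the diff.

# Assumed shape of the nested test module; __init__ is hidden by the diff view.
class SimpleWithLinear(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(2, 2)  # assumed 2x2 to match the input

  def forward(self, x):
    return self.linear(x)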

    ### implement one new function for xlacomputation generation with post-order
    print("before run _get_xla_computation: !!!!!!!!!")

@@ -219,11 +170,11 @@ def forward(self, x):
    else:
      print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

+  # this test should be modified/enabled after merging PR #6867
  def test_while_loop_get_xlacomputation_tpu_simple_linear_while_loop(self):

    xm.mark_step()
    device = xm.xla_device()
-    #device = ''
    torch.set_grad_enabled(False)

    class SimpleWithLinear(torch.nn.Module):
@@ -235,24 +186,11 @@ def __init__(self):
      def forward(self, x):
        x = self.linear(x)
        return x
-        # def cond_fn(it, x):
-        #   return it - self.dec > 0
-
-        # def body_fn(it, x):
-        #   return it - 1, self.linear(x)
-
-        # return while_loop(cond_fn, body_fn, (iter, x))

    simple_with_linear = SimpleWithLinear()
    simple_with_linear.to(device)
-    #breakpoint()
    input = torch.randn(2, 2).to(device)
-    # iter = torch.tensor(3, device=device)
-    # res = simple_with_linear(iter, input)
    t3 = simple_with_linear(input)
-    # t1 = torch.randn(20, 5).to(device)
-    # t2 = torch.randn(20, 5).to(device)
-    # t3 = torch.add(t1, t2)

    def cond_fn(upper, lower, one_value, x, input_value, output_value, *args):
      return lower[0] < upper[0]

@@ -262,22 +200,6 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
      output_value = simple_with_linear(input_value)
      res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()]
      return tuple(res)
-      # bn_list = []
-      # for name, param in simple_with_linear.named_parameters():
-      #   if name[:2] == 'bn':
-      #     bn_list.append(param)
-
-      #     res.insert(-1, param)
-
-      # # add still exist bn_list if the last additional_inputs is bn- pre
-      # # add at the tail
-      # if len(bn_list) != 0:
-      #   output_value = res[-1]
-      #   bn_list.reverse()
-      #   res = res[:-1] + bn_list
-      #   res.append(output_value)
-      #   bn_list = []
-      # return tuple(res)
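The top of body_fn, where new_lower is presumably computed from lower and one_value, sits above this hunk and is not shown. For orientation, here is a hypothetical driver for this cond_fn/body_fn pair using the experimental while_loop frontend these tests build toward; the import path, call signature, and tensor shapes are assumptions based on this era of the repo, not part of the diff.

# Hypothetical driver; while_loop carries the tuple through body_fn until
# cond_fn returns False. All names below except cond_fn/body_fn are made up.
from torch_xla.experimental.fori_loop import while_loop  # assumed import path

upper = torch.tensor([50], dtype=torch.int32, device=device)
lower = torch.tensor([0], dtype=torch.int32, device=device)
one_value = torch.tensor([1], dtype=torch.int32, device=device)
init_val = torch.tensor([1], dtype=torch.int32, device=device)
input_value = torch.randn(2, 2).to(device)
output_value = torch.zeros([2, 2], dtype=torch.float32, device=device)

result = while_loop(
    cond_fn, body_fn,
    (upper, lower, one_value, init_val, input_value, output_value))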

    ### implement one new function for xlacomputation generation with post-order
    print("before run _get_xla_computation: !!!!!!!!!")

@@ -291,10 +213,8 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
      print("print computation from _get_xla_computation: null !!!!!!!!!!!!!")

    ### get xlacomputation via PyLoweringContext
-    # body_result = body_fn(*fake_carried_inputs)
    body_ctx = torch_xla._XLAC.lowering.LoweringContext()
    body_ctx.set_name_string("bodyctx")
-    # body_ctx.buildforiloop(list(body_result), [])
    body_ctx.buildforiloop(list(t3), [])
    body_hlo = body_ctx.hlo()
    body_computation = xb.computation_from_module_proto("bodycomputation",
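The last hunk is truncated by the diff view after the computation_from_module_proto call. For reference, the full PyLoweringContext round trip it performs, with each step annotated; this sketch completes the visible call with body_hlo, the only proto in scope, and these lowering bindings are internal rather than a documented API.

# Annotated round trip: lazy IR -> HloModule proto -> XlaComputation -> text.
body_ctx = torch_xla._XLAC.lowering.LoweringContext()
body_ctx.set_name_string("bodyctx")    # label for the lowered computation
body_ctx.buildforiloop(list(t3), [])   # lower the IR rooted at t3, no extras
body_hlo = body_ctx.hlo()              # serialized HloModule proto
body_computation = xb.computation_from_module_proto("bodycomputation",
                                                    body_hlo)
print(xb.get_computation_hlo(body_computation))  # human-readable HLO text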
