passed add, sub and linear
ManfeiBai committed May 20, 2024
1 parent a6677af commit 76c2e7a
Showing 2 changed files with 96 additions and 35 deletions.
130 changes: 95 additions & 35 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
@@ -346,7 +346,7 @@ def body_fn(iteri, x):
# print("expected: ", expected)
# print("l_in_0: ", l_in_0)

def test_while_loop_tpu_MNIST_target_inside_loop_may16_2238pm(self):
def test_while_loop_tpu_MNIST_target_inside_loop_may19_2300pm(self):
xm.mark_step()
device = xm.xla_device()
torch.set_grad_enabled(False)
@@ -375,57 +375,117 @@ def __init__(self):
self.bn2 = torch.nn.BatchNorm2d(20)
self.fc1 = torch.nn.Linear(500, 50)
self.fc2 = torch.nn.Linear(50, 10)
self.weight_bias_lists = []
self.bn_weight_bias_lists = []
# self.register_buffer("dec", torch.tensor(1))

def forward(self, iteri, x, y):

def cond_fn(iteri, x, y):
return iteri > 0

def body_fn(iteri, x, y):
def body_fn(iteri, x):

y = F.relu(F.max_pool2d(self.conv1(x), 2))
y = self.bn1(y)
y = F.relu(F.max_pool2d(self.conv2(y), 2))
y = self.bn2(y)
y = torch.flatten(y, 1)
y = F.relu(self.fc1(y))
y = self.fc2(y)

# add layers para manually
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv1.named_parameters())
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv2.named_parameters())
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc1.named_parameters())
insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc2.named_parameters())
# y = self.bn1(y)
# y = F.relu(F.max_pool2d(self.conv2(y), 2))
# y = self.bn2(y)
# y = torch.flatten(y, 1)
# y = F.relu(self.fc1(y))
# y = self.fc2(y)

insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())
# return iteri - 1, F.log_softmax(y, dim=1)
return iteri - 1, y

return iteri - 1, x, F.log_softmax(y, dim=1)

return _xla_while_loop_target_first_second_clean_version(cond_fn, body_fn, (iteri, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# return _xla_while_loop_target_first_second_clean_version_s32(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# return _xla_while_loop_target_first_second_clean_version_s32_old(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# return _xla_while_loop_target_first_second_clean_version_s32_may16_1603pm(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
return while_loop(cond_fn, body_fn, (iteri, x))

mnist = MNIST()
mnist.to(device)
bs=16
l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
l_out = torch.randn(bs, 10, dtype=torch.float32, device=device)
# l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
l_in_0 = torch.randn(16, 1, 28, 28, dtype=torch.float32, device=device)
iteri = torch.tensor(3, dtype=torch.int64, device=device)
# print("dtype iter: ", iter.dtype())
# print("type iter: ", iter.type())
# print("dtype ite 1: ", torch.dtype(iter))
# print("dtype iter 2: ", iter.type().dtype)
# print("dtype iter 3: ", iter.dtype)
res = mnist(iteri, l_in_0, l_out)
res = mnist(iteri, l_in_0)
print("res: ", res[-1])


# def test_while_loop_tpu_MNIST_target_inside_loop_may16_2238pm(self):
# xm.mark_step()
# device = xm.xla_device()
# torch.set_grad_enabled(False)

# n_epochs = 3
# batch_size_train = 8 # 64
# batch_size_test = 10 # 1000
# learning_rate = 0.01
# momentum = 0.5
# log_interval = 10
# random_seed = 1
# torch.backends.cudnn.enabled = False
# torch.manual_seed(random_seed)

# ### load data
# test_loader = xu.SampleGenerator(
# data=(torch.zeros(8, 1, 28,28), torch.zeros(8, dtype=torch.int64)),
# sample_count=1000 // 8 // xm.xrt_world_size())

# class MNIST(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5, stride=1, padding=2)
# self.bn1 = torch.nn.BatchNorm2d(10)
# self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
# self.bn2 = torch.nn.BatchNorm2d(20)
# self.fc1 = torch.nn.Linear(500, 50)
# self.fc2 = torch.nn.Linear(50, 10)
# self.weight_bias_lists = []
# self.bn_weight_bias_lists = []
# # self.register_buffer("dec", torch.tensor(1))

# def forward(self, iteri, x, y):

# def cond_fn(iteri, x, y):
# return iteri > 0

# def body_fn(iteri, x, y):

# y = F.relu(F.max_pool2d(self.conv1(x), 2))
# y = self.bn1(y)
# y = F.relu(F.max_pool2d(self.conv2(y), 2))
# y = self.bn2(y)
# y = torch.flatten(y, 1)
# y = F.relu(self.fc1(y))
# y = self.fc2(y)

# # add layers para manually
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv1.named_parameters())
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn1.named_parameters())
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.conv2.named_parameters())
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.bn2.named_parameters())
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc1.named_parameters())
# insert_model_pars_into_additional_inputs(self.weight_bias_lists, self.fc2.named_parameters())

# insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn2.named_parameters())
# insert_model_pars_into_additional_inputs(self.bn_weight_bias_lists, self.bn1.named_parameters())

# return iteri - 1, x, F.log_softmax(y, dim=1)

# return _xla_while_loop_target_first_second_clean_version(cond_fn, body_fn, (iteri, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# # return _xla_while_loop_target_first_second_clean_version_s32(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# # return _xla_while_loop_target_first_second_clean_version_s32_old(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)
# # return _xla_while_loop_target_first_second_clean_version_s32_may16_1603pm(cond_fn, body_fn, (iter, x, y), self.weight_bias_lists, self.bn_weight_bias_lists)

# mnist = MNIST()
# mnist.to(device)
# bs=16
# l_in_0 = torch.randn(bs, 1, 28, 28, dtype=torch.float32, device=device)
# l_out = torch.randn(bs, 10, dtype=torch.float32, device=device)
# iteri = torch.tensor(3, dtype=torch.int64, device=device)
# # print("dtype iter: ", iter.dtype())
# # print("type iter: ", iter.type())
# # print("dtype ite 1: ", torch.dtype(iter))
# # print("dtype iter 2: ", iter.type().dtype)
# # print("dtype iter 3: ", iter.dtype)
# res = mnist(iteri, l_in_0, l_out)
# print("res: ", res[-1])

# def test_while_loop_tpu_simple_linear_target_inside_loop_may16_2238pm(self):
# xm.mark_step()
# device = xm.xla_device()
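
For context, the commit title says the add, sub and linear cases now pass through this dispatch path. Below is a minimal sketch of the add case under the same `(iteri, x)` carried-input calling convention as the test above; the imports, shapes, and expected result are assumptions for illustration, not part of the diff.

```python
# Hedged sketch: a simple "add" loop via the while_loop higher-order op,
# mirroring the (iteri, x) carried-input convention used in the test above.
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(iteri, x):
  # keep iterating while the counter is positive
  return iteri > 0

def body_fn(iteri, x):
  # decrement the counter and add 1 to the carried tensor
  return iteri - 1, x + 1

iteri = torch.tensor(3, dtype=torch.int64, device=device)
x = torch.zeros(2, 2, dtype=torch.float32, device=device)
iteri, x = while_loop(cond_fn, body_fn, (iteri, x))
print(x)  # expected: a 2x2 tensor of 3s, assuming three iterations run
```
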
1 change: 1 addition & 0 deletions torch_xla/experimental/fori_loop.py
@@ -278,6 +278,7 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
# TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
# cond_fn&body_fn: callable
# carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
print("dispatchkey here !!!")
if additional_inputs is None:
additional_inputs = tuple()
# print("arrive @while_loop_op.py_impl(DispatchKey.XLA)")
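
The one-line change above adds a debug print inside the XLA `while_loop` implementation, whose signature takes `carried_inputs` plus an optional `additional_inputs` tuple (normalized to `()` when absent). Below is a hedged sketch of the linear case from the commit title under that convention; the module, shapes, and the automatic lifting of the layer's weight and bias into `additional_inputs` are assumptions for illustration, not shown in the diff.

```python
# Hedged sketch: a Linear layer applied inside the loop body. If closure
# parameters are lifted for the XLA backend, the layer's weight and bias
# would travel as additional_inputs rather than being inserted manually.
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

torch.set_grad_enabled(False)  # the test above also runs with grad disabled
device = xm.xla_device()
linear = torch.nn.Linear(2, 2).to(device)

def cond_fn(iteri, x):
  return iteri > 0

def body_fn(iteri, x):
  # apply the same linear layer on every iteration
  return iteri - 1, linear(x)

iteri = torch.tensor(3, dtype=torch.int64, device=device)
x = torch.ones(2, 2, dtype=torch.float32, device=device)
iteri, x = while_loop(cond_fn, body_fn, (iteri, x))
print(x)  # final carried value after three loop iterations
```
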
