Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai committed Apr 30, 2024
1 parent 95227f0 commit 06cb773
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions test/test_test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,30 +257,32 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):
# bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement
res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()]
bn_list = []
bn_flag = False
# bn_flag = False
for name, param in simple_with_linear.named_parameters():
if name[:2]=='bn':
bn_flag = True
# bn_flag = True
bn_list.insert(-1, param) # dumpicate # continue # skip bn
else:
bn_flag = False
# else:
# bn_flag = False

res.insert(-1, param)

if (not bn_flag) and (len(bn_list) !=0): # False
output_value = res[-1]
res = res[:-1] + bn_list # + res[-1]
res.append(output_value)
bn_list = []
# torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device))
# if (not bn_flag) and (len(bn_list) !=0): # False
# output_value = res[-1]
# res = res[:-1] + bn_list # + res[-1]
# res.append(output_value)
# bn_list = []
# # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device))

### !!! add still exist bn_list if the last additional_inputs is bn- pre
if bn_flag and (len(bn_list) !=0):
# if bn_flag and (len(bn_list) !=0):
### !!! add at the tile
if len(bn_list) !=0:
output_value = res[-1]
res = res[:-1] + bn_list # + res[-1]
res.append(output_value)
bn_list = []
bn_flag = False
# bn_flag = False

return tuple(res)
# return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add(
Expand Down Expand Up @@ -321,32 +323,34 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args):

additional_inputs = []
bn_list = []
bn_flag = False
# bn_flag = False
for name, param in simple_with_linear.named_parameters():
# if name[:2]=='bn':
# additional_inputs.append(param) # dumplicate
if name[:2]=='bn':
# print("catch: ", name)
bn_flag = True
# bn_flag = True
bn_list.insert(-1, param) # dumpicate # continue # skip bn
# print("newest bn_list: ", bn_list)
else:
bn_flag = False
# else:
# bn_flag = False

# additional_inputs.insert(-1, param)
additional_inputs.append(param)

if (not bn_flag) and (len(bn_list) !=0): # False
additional_inputs =additional_inputs + bn_list
# print("added bn_list: ", bn_list)
bn_list = []
# if (not bn_flag) and (len(bn_list) !=0): # False
# additional_inputs =additional_inputs + bn_list
# # print("added bn_list: ", bn_list)
# bn_list = []

### !!! add still exist bn_list if the last additional_inputs is bn- pre
if bn_flag and (len(bn_list) !=0):
additional_inputs =additional_inputs + bn_list
# if bn_flag and (len(bn_list) !=0):
### !!! add duplicated bn argus as the tile of the list
if len(bn_list) !=0:
additional_inputs = additional_inputs + bn_list
# print("added bn_list: ", bn_list)
bn_list = []
bn_flag = False
# bn_flag = False

print("final additional_inputs: ", additional_inputs)
# print("in mnist additional_inputs: ", additional_inputs)
Expand Down

0 comments on commit 06cb773

Please sign in to comment.