diff --git a/test/test_test_mnist.py b/test/test_test_mnist.py index dd615b7705a4..54036c86bca5 100644 --- a/test/test_test_mnist.py +++ b/test/test_test_mnist.py @@ -257,30 +257,32 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): # bias = simple_with_linear.bias # not be used actually, initialized as placeholder xlacomputation requirement res = [upper.clone(), new_lower.clone(), one_value.clone(), torch.add(one_value, x), input_value.clone(), output_value.clone()] bn_list = [] - bn_flag = False + # bn_flag = False for name, param in simple_with_linear.named_parameters(): if name[:2]=='bn': - bn_flag = True + # bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn - else: - bn_flag = False + # else: + # bn_flag = False res.insert(-1, param) - if (not bn_flag) and (len(bn_list) !=0): # False - output_value = res[-1] - res = res[:-1] + bn_list # + res[-1] - res.append(output_value) - bn_list = [] - # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) + # if (not bn_flag) and (len(bn_list) !=0): # False + # output_value = res[-1] + # res = res[:-1] + bn_list # + res[-1] + # res.append(output_value) + # bn_list = [] + # # torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) ### !!! add still exist bn_list if the last additional_inputs is bn- pre - if bn_flag and (len(bn_list) !=0): + # if bn_flag and (len(bn_list) !=0): + ### !!! add at the tile + if len(bn_list) !=0: output_value = res[-1] res = res[:-1] + bn_list # + res[-1] res.append(output_value) bn_list = [] - bn_flag = False + # bn_flag = False return tuple(res) # return (upper.clone(), new_lower.clone(), one_value.clone(), torch.add( @@ -321,32 +323,34 @@ def body_fn(upper, lower, one_value, x, input_value, output_value, *args): additional_inputs = [] bn_list = [] - bn_flag = False + # bn_flag = False for name, param in simple_with_linear.named_parameters(): # if name[:2]=='bn': # additional_inputs.append(param) # dumplicate if name[:2]=='bn': # print("catch: ", name) - bn_flag = True + # bn_flag = True bn_list.insert(-1, param) # dumpicate # continue # skip bn # print("newest bn_list: ", bn_list) - else: - bn_flag = False + # else: + # bn_flag = False # additional_inputs.insert(-1, param) additional_inputs.append(param) - if (not bn_flag) and (len(bn_list) !=0): # False - additional_inputs =additional_inputs + bn_list - # print("added bn_list: ", bn_list) - bn_list = [] + # if (not bn_flag) and (len(bn_list) !=0): # False + # additional_inputs =additional_inputs + bn_list + # # print("added bn_list: ", bn_list) + # bn_list = [] ### !!! add still exist bn_list if the last additional_inputs is bn- pre - if bn_flag and (len(bn_list) !=0): - additional_inputs =additional_inputs + bn_list + # if bn_flag and (len(bn_list) !=0): + ### !!! add duplicated bn argus as the tile of the list + if len(bn_list) !=0: + additional_inputs = additional_inputs + bn_list # print("added bn_list: ", bn_list) bn_list = [] - bn_flag = False + # bn_flag = False print("final additional_inputs: ", additional_inputs) # print("in mnist additional_inputs: ", additional_inputs)