diff --git a/policy_driven_attack/pd_attack.py b/policy_driven_attack/pd_attack.py
index a0f0ee3..383ffbe 100644
--- a/policy_driven_attack/pd_attack.py
+++ b/policy_driven_attack/pd_attack.py
@@ -53,6 +53,7 @@ def __init__(self, victim_query, epsilon, external_init_adv_image, use_pytorch_r
         loader = DataLoaderMaker.get_test_attacked_data(dataset, batch_size, True)
         self.dataset_loader = loader
         self.total_images = len(self.dataset_loader.dataset)
+        self.load_random_class_image = args.load_random_class_image
 
     def calc_distance(self, x1, x2):
         diff = x1.cuda() - x2.cuda()
@@ -64,32 +65,32 @@ def calc_distance(self, x1, x2):
         else:
             raise NotImplementedError('Unknown norm: {}'.format(self.norm_type))
 
-    def get_image_of_target_class(self, dataset_name, target_labels):
+    def get_image_of_target_class(self, dataset_name, target_labels, index):
         images = []
         for label in target_labels:  # length of target_labels is 1
-            if dataset_name == "ImageNet":
-                dataset = ImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
-            elif dataset_name == "CIFAR-10":
-                dataset = CIFAR10Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
-            elif dataset_name == "CIFAR-100":
-                dataset = CIFAR100Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
-            elif dataset_name == "TinyImageNet":
-                dataset = TinyImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
-            index = np.random.randint(0, len(dataset))
-            image, true_label = dataset[index]
-            image = image.unsqueeze(0)
-            if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
-                image = F.interpolate(image,
-                                      size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
-                                      mode='bicubic', align_corners=False)
-            with torch.no_grad():
-                logits = self.victim_query.query(image.cuda(), True, True)  # --[debug]
-
-            max_recursive_loop_limit = 100
-            loop_count = 0
-            while not logits.item() and loop_count < max_recursive_loop_limit:
-                loop_count += 1
+            if self.load_random_class_image:
+                initial_images = np.load(
+                    "{}/attacked_images/{}/{}_targeted-attack-initial-images.npz".format(PROJECT_PATH, dataset_name,
+                                                                                         dataset_name),
+                    allow_pickle=True)
+                image = torch.from_numpy(initial_images[str(label.item())])
+                if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
+                    image = F.interpolate(image,
+                                          size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
+                                          mode='bicubic',
+                                          align_corners=False)
+                with torch.no_grad():
+                    self.victim_query.query(image.cuda(), True, True)  # --[debug]
+            else:
+                if dataset_name == "ImageNet":
+                    dataset = ImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
+                elif dataset_name == "CIFAR-10":
+                    dataset = CIFAR10Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
+                elif dataset_name == "CIFAR-100":
+                    dataset = CIFAR100Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
+                elif dataset_name == "TinyImageNet":
+                    dataset = TinyImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
                 index = np.random.randint(0, len(dataset))
                 image, true_label = dataset[index]
                 image = image.unsqueeze(0)
@@ -98,12 +99,26 @@ def get_image_of_target_class(self, dataset_name, target_labels):
                                           size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
                                           mode='bicubic', align_corners=False)
                 with torch.no_grad():
-                    logits = self.victim_query.query(image.cuda(), True, True)
+                    logits = self.victim_query.query(image.cuda(), True, True)  # --[debug]
+
+                max_recursive_loop_limit = 100
+                loop_count = 0
+                while not logits.item() and loop_count < max_recursive_loop_limit:
+                    loop_count += 1
+                    index = np.random.randint(0, len(dataset))
+                    image, true_label = dataset[index]
+                    image = image.unsqueeze(0)
+                    if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
+                        image = F.interpolate(image,
+                                              size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
+                                              mode='bicubic', align_corners=False)
+                    with torch.no_grad():
+                        logits = self.victim_query.query(image.cuda(), True, True)
 
-            if loop_count == max_recursive_loop_limit:
-                # The program cannot find a valid image from the validation set.
-                return None
-            assert true_label == label.item()
+                if loop_count == max_recursive_loop_limit:
+                    # The program cannot find a valid image from the validation set.
+                    return None
+                assert true_label == label.item()
             images.append(torch.squeeze(image))
         return torch.stack(images)  # B,C,H,W
@@ -136,7 +151,7 @@ def decision_function(self, victim, x, sync_best=True, no_count=False):
         return torch.cat(outs)
 
     # initialization for the attack
-    def initialize(self, sample, target_images, true_labels, target_labels):
+    def initialize(self, sample, target_images, true_labels, target_labels, batch_index):
         """
         sample: the shape of sample is [C,H,W] without batch-size
         Efficient Implementation of BlendedUniformNoiseAttack in Foolbox.
@@ -162,7 +177,7 @@ def initialize(self, sample, target_images, true_labels, target_labels):
                                                                size=target_labels[invalid_target_index].size()).long()
                     invalid_target_index = target_labels.eq(true_labels)
-                initialization = self.get_image_of_target_class(self.dataset,target_labels).squeeze()
+                initialization = self.get_image_of_target_class(self.dataset,target_labels, batch_index).squeeze()
                 return initialization, 1
                 # assert num_eval < 1e4, "Initialization failed! Use a misclassified image as `target_image`"
         # Binary search to minimize l2 distance to original image.
@@ -285,10 +300,11 @@ def attack_all_images(self, args):
         output_fields = ('grad', 'std')
 
         # make upsampler and downsampler
+
         if args.grad_size != 0:
             # upsampler: grad to image; downsampler: image to grad
-            upsampler = lambda x: F.interpolate(x, size=victim.input_size[-1], mode="bicubic" if args.dataset=="ImageNet" else "bilinear", align_corners=True)
-            downsampler = lambda x: F.interpolate(x, size=args.grad_size, mode="bicubic" if args.dataset=="ImageNet" else "bilinear", align_corners=True)
+            upsampler = lambda x: F.interpolate(x, size=victim.input_size[-1], mode='bicubic' if args.dataset=="ImageNet" else "bilinear", align_corners=True)
+            downsampler = lambda x: F.interpolate(x, size=args.grad_size, mode='bicubic' if args.dataset=="ImageNet" else "bilinear", align_corners=True)
         else:
             # no resize, upsampler = downsampler = identity
             upsampler = downsampler = lambda x: x
@@ -641,7 +657,7 @@ def do_pre_tune(adv_image_, image_, label_, target_):
                 victim.reset(image=image, label=None, target_label=target_labels, attack_type=args.targeted,
                              norm_type=args.norm_type)
-                target_images = self.get_image_of_target_class(self.dataset, target_labels)
+                target_images = self.get_image_of_target_class(self.dataset, target_labels, batch_index)
                 target = target_labels
                 if target_images is None:
                     log.info("{}-th image cannot get a valid target class image to initialize!".format(batch_index + 1))
@@ -649,7 +665,7 @@ def do_pre_tune(adv_image_, image_, label_, target_):
             else:
                 target_labels = None
                 target_images = None
-            init_adv_image, num_eval = self.initialize(image, target_images, label, target_labels)
+            init_adv_image, num_eval = self.initialize(image, target_images, label, target_labels, batch_index)
             ##########################
             if init_adv_image is None:
                 log.info('Initial point not found, {}-th image: image_id: {}, skip this image'.format(
@@ -1674,12 +1690,13 @@ def parse_args():
     parser.add_argument('--target_type', type=str, default='increment',
                         choices=['random', "load_random", 'least_likely', "increment"])
     parser.add_argument('--all-archs', action="store_true")
-    parser.add_argument('--json-config', type=str, default='../configures/PDA.json',
+    parser.add_argument('--json-config', type=str, default='{}/TangentAttack-main/configures/PDA.json'.format(PROJECT_PATH),
                         help='a configures file to be passed in instead of arguments')
     parser.add_argument('--ssh', action='store_true', help='whether or not we are executing command via ssh.'
                                                            'If set to True, we will not print anything to screen and only redirect them to log file')  # used
-
+    parser.add_argument('--load-random-class-image', action='store_true',
+                        help='load a random image from the target class')  # npz {"0":, "1": ,"2": }
     if len(sys.argv) == 1:
         parser.print_help()
         sys.exit(1)
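
Note on the npz file consumed by --load-random-class-image: the new branch in get_image_of_target_class() reads {PROJECT_PATH}/attacked_images/{dataset}/{dataset}_targeted-attack-initial-images.npz and expects one entry per class, keyed by the class label as a string (the "0", "1", "2" keys hinted at in the argparse comment), each holding a single image that already carries a batch dimension, since the loaded array is passed straight to torch.from_numpy() and F.interpolate(). Below is a minimal sketch of how such a file could be produced; the helper name build_initial_image_npz and the assumption that the dataset yields (tensor of shape [C,H,W], integer label) pairs are illustrative and not part of this patch.

    # build_initial_image_npz.py: hypothetical helper, not included in this patch.
    import numpy as np

    def build_initial_image_npz(dataset, num_classes, out_path):
        """Pick one image per class and save them as an .npz archive whose keys are
        class labels as strings, matching initial_images[str(label.item())] above."""
        images = {}
        for index in np.random.permutation(len(dataset)):
            image, label = dataset[int(index)]  # assumed: torch.Tensor [C,H,W], integer label
            key = str(int(label))
            if key not in images:
                # Keep an explicit batch dimension: the loading code feeds the array
                # directly into torch.from_numpy() and then F.interpolate(), which
                # needs a 4-D (N,C,H,W) input for bicubic resizing.
                images[key] = image.unsqueeze(0).numpy()
            if len(images) == num_classes:
                break
        np.savez(out_path, **images)

This sketch does not verify that each stored image is actually classified as its class by the victim model; the load_random_class_image branch above likewise issues only a single, unchecked query on the loaded image, so a misclassified entry would go unnoticed at initialization time.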