Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

20230731/add pda to bug torch #1

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 53 additions & 36 deletions policy_driven_attack/pd_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, victim_query, epsilon, external_init_adv_image, use_pytorch_r
loader = DataLoaderMaker.get_test_attacked_data(dataset, batch_size, True)
self.dataset_loader = loader
self.total_images = len(self.dataset_loader.dataset)
self.load_random_class_image = args.load_random_class_image

def calc_distance(self, x1, x2):
diff = x1.cuda() - x2.cuda()
Expand All @@ -64,32 +65,32 @@ def calc_distance(self, x1, x2):
else:
raise NotImplementedError('Unknown norm: {}'.format(self.norm_type))

def get_image_of_target_class(self, dataset_name, target_labels):
def get_image_of_target_class(self, dataset_name, target_labels, index):

images = []
for label in target_labels: # length of target_labels is 1
if dataset_name == "ImageNet":
dataset = ImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "CIFAR-10":
dataset = CIFAR10Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "CIFAR-100":
dataset = CIFAR100Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "TinyImageNet":
dataset = TinyImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
index = np.random.randint(0, len(dataset))
image, true_label = dataset[index]
image = image.unsqueeze(0)
if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
image = F.interpolate(image,
size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
mode='bicubic', align_corners=False)
with torch.no_grad():
logits = self.victim_query.query(image.cuda(), True, True) # --[debug]

max_recursive_loop_limit = 100
loop_count = 0
while not logits.item() and loop_count < max_recursive_loop_limit:
loop_count += 1
if self.load_random_class_image:
initial_images = np.load(
"{}/attacked_images/{}/{}_targeted-attack-initial-images.npz".format(PROJECT_PATH, dataset_name,
dataset_name),
allow_pickle=True)
image = torch.from_numpy(initial_images[str(label.item())])
if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
image = F.interpolate(image,
size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
mode='bicubic',
align_corners=False)
with torch.no_grad():
self.victim_query.query(image.cuda(), True, True)# --[debug]
else:
if dataset_name == "ImageNet":
dataset = ImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "CIFAR-10":
dataset = CIFAR10Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "CIFAR-100":
dataset = CIFAR100Dataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
elif dataset_name == "TinyImageNet":
dataset = TinyImageNetDataset(IMAGE_DATA_ROOT[dataset_name], label.item(), "validation")
index = np.random.randint(0, len(dataset))
image, true_label = dataset[index]
image = image.unsqueeze(0)
Expand All @@ -98,12 +99,26 @@ def get_image_of_target_class(self, dataset_name, target_labels):
size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
mode='bicubic', align_corners=False)
with torch.no_grad():
logits = self.victim_query.query(image.cuda(), True, True)
logits = self.victim_query.query(image.cuda(), True, True) # --[debug]

max_recursive_loop_limit = 100
loop_count = 0
while not logits.item() and loop_count < max_recursive_loop_limit:
loop_count += 1
index = np.random.randint(0, len(dataset))
image, true_label = dataset[index]
image = image.unsqueeze(0)
if dataset_name == "ImageNet" and self.victim_query.net.input_size[-1] != 299:
image = F.interpolate(image,
size=(self.victim_query.net.input_size[-2], self.victim_query.net.input_size[-1]),
mode='bicubic', align_corners=False)
with torch.no_grad():
logits = self.victim_query.query(image.cuda(), True, True)

if loop_count == max_recursive_loop_limit:
# The program cannot find a valid image from the validation set.
return None
assert true_label == label.item()
if loop_count == max_recursive_loop_limit:
# The program cannot find a valid image from the validation set.
return None
assert true_label == label.item()
images.append(torch.squeeze(image))

return torch.stack(images) # B,C,H,W
Expand Down Expand Up @@ -136,7 +151,7 @@ def decision_function(self, victim, x, sync_best=True, no_count=False):
return torch.cat(outs)

# initialization for the attack
def initialize(self, sample, target_images, true_labels, target_labels):
def initialize(self, sample, target_images, true_labels, target_labels, batch_index):
"""
sample: the shape of sample is [C,H,W] without batch-size
Efficient Implementation of BlendedUniformNoiseAttack in Foolbox.
Expand All @@ -162,7 +177,7 @@ def initialize(self, sample, target_images, true_labels, target_labels):
size=target_labels[invalid_target_index].size()).long()
invalid_target_index = target_labels.eq(true_labels)

initialization = self.get_image_of_target_class(self.dataset,target_labels).squeeze()
initialization = self.get_image_of_target_class(self.dataset,target_labels, batch_index).squeeze()
return initialization, 1
# assert num_eval < 1e4, "Initialization failed! Use a misclassified image as `target_image`"
# Binary search to minimize l2 distance to original image.
Expand Down Expand Up @@ -285,10 +300,11 @@ def attack_all_images(self, args):
output_fields = ('grad', 'std')

# make upsampler and downsampler

if args.grad_size != 0:
# upsampler: grad to image; downsampler: image to grad
upsampler = lambda x: F.interpolate(x, size=victim.input_size[-1], mode="bicubic" if args.dataset=="ImageNet" else "bilinear", align_corners=True)
downsampler = lambda x: F.interpolate(x, size=args.grad_size, mode="bicubic" if args.dataset=="ImageNet" else "bilinear", align_corners=True)
upsampler = lambda x: F.interpolate(x, size=victim.input_size[-1], mode='bicubic' if args.dataset=="ImageNet" else "bilinear", align_corners=True)
downsampler = lambda x: F.interpolate(x, size=args.grad_size, mode='bicubic' if args.dataset=="ImageNet" else "bilinear", align_corners=True)
else:
# no resize, upsampler = downsampler = identity
upsampler = downsampler = lambda x: x
Expand Down Expand Up @@ -641,15 +657,15 @@ def do_pre_tune(adv_image_, image_, label_, target_):

victim.reset(image=image, label=None, target_label=target_labels,
attack_type=args.targeted, norm_type=args.norm_type)
target_images = self.get_image_of_target_class(self.dataset, target_labels)
target_images = self.get_image_of_target_class(self.dataset, target_labels, batch_index)
target = target_labels
if target_images is None:
log.info("{}-th image cannot get a valid target class image to initialize!".format(batch_index + 1))
continue
else:
target_labels = None
target_images = None
init_adv_image, num_eval = self.initialize(image, target_images, label, target_labels)
init_adv_image, num_eval = self.initialize(image, target_images, label, target_labels, batch_index)
##########################
if init_adv_image is None:
log.info('Initial point not found, {}-th image: image_id: {}, skip this image'.format(
Expand Down Expand Up @@ -1674,12 +1690,13 @@ def parse_args():
parser.add_argument('--target_type', type=str, default='increment',
choices=['random', "load_random", 'least_likely', "increment"])
parser.add_argument('--all-archs', action="store_true")
parser.add_argument('--json-config', type=str, default='../configures/PDA.json',
parser.add_argument('--json-config', type=str, default='{}/TangentAttack-main/configures/PDA.json'.format(PROJECT_PATH),
help='a configures file to be passed in instead of arguments')
parser.add_argument('--ssh', action='store_true',
help='whether or not we are executing command via ssh.'
'If set to True, we will not print anything to screen and only redirect them to log file') # used

parser.add_argument('--load-random-class-image', action='store_true',
help='load a random image from the target class') # npz {"0":, "1": ,"2": }
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
Expand Down