Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of PRBCD and GRBCD attacks #9

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
93 changes: 93 additions & 0 deletions examples/attack/targeted/rbcd_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Targeted structure attacks with PRBCD and GRBCD on Cora.

Trains a GCN victim model, then perturbs the graph structure around a
single target node and reports the model's prediction on that node
(1) before the attack, (2) after the evasion attack (original model,
perturbed graph), and (3) after a poisoning attack. The poisoning
evaluation is heuristic: the evasion perturbation is transferred to the
training graph and a fresh model is retrained on it.
"""
import os.path as osp

import torch
import torch_geometric.transforms as T

from greatx.attack.targeted import GRBCDAttack, PRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint
from greatx.utils import mark, split_nodes

dataset_name = 'Cora'
root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data')
dataset = GraphDataset(root=root, name=dataset_name,
                       transform=T.LargestConnectedComponents())

data = dataset[0]
splits = split_nodes(data.y, random_state=15)

num_features = data.x.size(-1)
num_classes = data.y.max().item() + 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ================================================================== #
#                      Attack Setting                                #
# ================================================================== #
target = 1  # target node to attack
target_label = data.y[target].item()

# ================================================================== #
#                      Before Attack                                 #
# ================================================================== #
trainer_before = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_before.pth', monitor='val_acc')
trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes),
                   callbacks=[ckp])
output = trainer_before.predict(data, mask=target)
print("Before attack:")
print(mark(output, target_label))

# ================================================================== #
#                      Attacking (PRBCDAttack)                       #
# ================================================================== #
attacker = PRBCDAttack(data, device=device)
attacker.setup_surrogate(trainer_before.model)
attacker.reset()
attacker.attack(target)

# ================================================================== #
#                      After evasion Attack                          #
# ================================================================== #
# Evasion: evaluate the unchanged victim model on the perturbed graph.
output = trainer_before.predict(attacker.data(), mask=target)
print("After evasion attack:")
print(mark(output, target_label))

# ================================================================== #
#                      After poisoning Attack                        #
# ================================================================== #
# NOTE: this poisoning evaluation is heuristic — the evasion
# perturbation is transferred to the training graph and a new model
# is trained from scratch on it.
trainer_after = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_after.pth', monitor='val_acc')
trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes),
                  callbacks=[ckp])
output = trainer_after.predict(attacker.data(), mask=target)
print("After poisoning attack:")
print(mark(output, target_label))

# ================================================================== #
#                      Attacking (GRBCDAttack)                       #
# ================================================================== #
# Repeat the same evaluation protocol with the greedy variant (GRBCD),
# starting again from the clean graph `data`.
attacker = GRBCDAttack(data, device=device)
attacker.setup_surrogate(trainer_before.model)
attacker.reset()
attacker.attack(target)

# ================================================================== #
#                      After evasion Attack                          #
# ================================================================== #
output = trainer_before.predict(attacker.data(), mask=target)
print("After evasion attack:")
print(mark(output, target_label))

# ================================================================== #
#                      After poisoning Attack                        #
# ================================================================== #
# Heuristic poisoning again: retrain on the GRBCD-perturbed graph.
trainer_after = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_after.pth', monitor='val_acc')
trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes),
                  callbacks=[ckp])
output = trainer_after.predict(attacker.data(), mask=target)
print("After poisoning attack:")
print(mark(output, target_label))
90 changes: 90 additions & 0 deletions examples/attack/untargeted/rbcd_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Untargeted structure attacks with PRBCD and GRBCD on Cora.

Trains a GCN victim model, then perturbs a budgeted fraction (5%) of
the graph structure to degrade test accuracy, reporting accuracy
(1) before the attack, (2) after the evasion attack (original model,
perturbed graph), and (3) after a poisoning attack. The poisoning
evaluation is heuristic: the evasion perturbation is transferred to
the training graph and a fresh model is retrained on it.
"""
import os.path as osp

import torch
import torch_geometric.transforms as T

from greatx.attack.untargeted import GRBCDAttack, PRBCDAttack
from greatx.datasets import GraphDataset
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint
from greatx.utils import split_nodes

dataset_name = 'Cora'
root = osp.join(osp.dirname(osp.realpath(__file__)), '../../..', 'data')
dataset = GraphDataset(root=root, name=dataset_name,
                       transform=T.LargestConnectedComponents())

data = dataset[0]
splits = split_nodes(data.y, random_state=15)

num_features = data.x.size(-1)
num_classes = data.y.max().item() + 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ================================================================== #
#                      Before Attack                                 #
# ================================================================== #
trainer_before = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_before.pth', monitor='val_acc')
trainer_before.fit(data, mask=(splits.train_nodes, splits.val_nodes),
                   callbacks=[ckp])
logs = trainer_before.evaluate(data, splits.test_nodes)
print(f"Before attack\n {logs}")

# ================================================================== #
#                      Attacking (PRBCDAttack)                       #
# ================================================================== #
attacker = PRBCDAttack(data, device=device)
attacker.setup_surrogate(
    trainer_before.model,
    victim_nodes=splits.test_nodes,
    # set True to use ground-truth labels
    ground_truth=False,
)
attacker.reset()
attacker.attack(0.05)  # budget: perturb 5% of edges

# ================================================================== #
#                      After evasion Attack                          #
# ================================================================== #
# Evasion: evaluate the unchanged victim model on the perturbed graph.
logs = trainer_before.evaluate(attacker.data(), splits.test_nodes)
print(f"After evasion attack\n {logs}")

# ================================================================== #
#                      After poisoning Attack                        #
# ================================================================== #
# NOTE: this poisoning evaluation is heuristic — the evasion
# perturbation is transferred to the training graph and a new model
# is trained from scratch on it.
trainer_after = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_after.pth', monitor='val_acc')
trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes),
                  callbacks=[ckp])
logs = trainer_after.evaluate(attacker.data(), splits.test_nodes)
print(f"After poisoning attack\n {logs}")

# ================================================================== #
#                      Attacking (GRBCDAttack)                       #
# ================================================================== #
# Repeat the same evaluation protocol with the greedy variant (GRBCD),
# starting again from the clean graph `data`.
attacker = GRBCDAttack(data, device=device)
attacker.setup_surrogate(
    trainer_before.model,
    victim_nodes=splits.test_nodes,
    # set True to use ground-truth labels
    ground_truth=False,
)
attacker.reset()
attacker.attack(0.05)  # budget: perturb 5% of edges

# ================================================================== #
#                      After evasion Attack                          #
# ================================================================== #
logs = trainer_before.evaluate(attacker.data(), splits.test_nodes)
print(f"After evasion attack\n {logs}")

# ================================================================== #
#                      After poisoning Attack                        #
# ================================================================== #
# Heuristic poisoning again: retrain on the GRBCD-perturbed graph.
trainer_after = Trainer(GCN(num_features, num_classes), device=device)
ckp = ModelCheckpoint('model_after.pth', monitor='val_acc')
trainer_after.fit(attacker.data(), mask=(splits.train_nodes, splits.val_nodes),
                  callbacks=[ckp])
logs = trainer_after.evaluate(attacker.data(), splits.test_nodes)
print(f"After poisoning attack\n {logs}")
5 changes: 4 additions & 1 deletion greatx/attack/targeted/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .targeted_attacker import TargetedAttacker
from .dice_attack import DICEAttack
from .fg_attack import FGAttack
from .gf_attack import GFAttack
Expand All @@ -6,7 +7,7 @@
from .pgd_attack import PGDAttack
from .random_attack import RandomAttack
from .sg_attack import SGAttack
from .targeted_attacker import TargetedAttacker
from .rbcd_attack import PRBCDAttack, GRBCDAttack

classes = __all__ = [
'TargetedAttacker',
Expand All @@ -18,4 +19,6 @@
'Nettack',
'GFAttack',
'PGDAttack',
'PRBCDAttack',
'GRBCDAttack',
]
Loading