Skip to content

Commit

Permalink
Don't load neighbor model unless needed
Browse files Browse the repository at this point in the history
  • Loading branch information
iamgroot42 committed Jan 26, 2024
1 parent 1a8ff53 commit 8640376
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
70 changes: 38 additions & 32 deletions new_mi_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def run_blackbox_attacks(
runnable_attacks.append(a)
attacks = runnable_attacks

neighborhood_attacker = NeighborhoodAttack(config, target_model)
neighborhood_attacker.prepare()
if BlackBoxAttacks.NEIGHBOR in attacks:
neighborhood_attacker = NeighborhoodAttack(config, target_model)
neighborhood_attacker.prepare()

results = defaultdict(list)
for classification in keys_care_about:
Expand All @@ -93,9 +94,8 @@ def run_blackbox_attacks(

# For each batch of data
# TODO: Batch-size isn't really "batching" data - change later
iterator = range(math.ceil(n_samples / batch_size))
if verbose:
iterator = range(math.ceil(n_samples / batch_size))
else:
iterator = tqdm(iterator, desc=f"Computing criterion")

for batch in iterator:
Expand Down Expand Up @@ -240,7 +240,10 @@ def run_blackbox_attacks(

# Update collected scores for each sample with ref-based attack scores
for classification, result in results.items():
for r in tqdm(result, desc="Ref scores"):
itr = result
if verbose:
itr = tqdm(itr, desc="Ref scores")
for r in itr:
ref_model_scores = []
for i, s in enumerate(r["sample"]):
if config.pretokenized:
Expand Down Expand Up @@ -580,32 +583,35 @@ def edit(x, n: int):
x = base_model.tokenizer.decode(x_tok)
return x

if config.load_from_cache:
with open(
f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json",
"r",
) as f:
other_members_data = json.load(f)
n_try = list(other_members_data.keys())
n_trials = len(other_members_data[n_try[0]])
elif config.dump_cache:
# Try out multiple "distances"
n_try = [1, 5, 10, 25, 100]
# With multiple trials
n_trials = 50
other_members_data = {}
for n in tqdm(n_try, "Generating edited members"):
trials = {}
for i in tqdm(range(n_trials)):
trials[i] = [edit(x, n) for x in data_member]
other_members_data[n] = trials
with open(
f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json",
"w",
) as f:
json.dump(other_members_data, f)
print("Data dumped! Please re-run with load_from_cache set to True")
exit(0)
# """
with open(
f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json",
"r",
) as f:
other_members_data = json.load(f)
n_try = list(other_members_data.keys())
n_trials = len(other_members_data[n_try[0]])
# """

"""
# Try out multiple "distances"
n_try = [1, 5, 10, 25, 100]
# With multiple trials
n_trials = 20
other_members_data = {}
for n in tqdm(n_try, "Generating edited members"):
trials = {}
for i in tqdm(range(n_trials)):
trials[i] = [edit(x, n) for x in data_member]
other_members_data[n] = trials
with open(
f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json",
"w",
) as f:
json.dump(other_members_data, f)
print("Data dumped! Please re-run with load_from_cache set to True")
exit(0)
"""
### </LOGIC FOR SPECIFIC EXPERIMENTS>

# Using thresholds returned in blackbox_outputs, compute AUCs and ROC curves for other non-member sources
Expand All @@ -631,7 +637,7 @@ def edit(x, n: int):
)
pbar.update(1)

for attack in other_blackbox_predictions.keys():
for attack in config.blackbox_attacks:
score_dict[attack][n][i] = other_blackbox_predictions[attack]["member"]

pbar.close()
Expand Down
5 changes: 3 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def run_blackbox_attacks(
runnable_attacks.append(a)
attacks = runnable_attacks

neighborhood_attacker = NeighborhoodAttack(config, target_model)
neighborhood_attacker.prepare()
if BlackBoxAttacks.NEIGHBOR in attacks:
neighborhood_attacker = NeighborhoodAttack(config, target_model)
neighborhood_attacker.prepare()

results = defaultdict(list)
for classification in keys_care_about:
Expand Down

0 comments on commit 8640376

Please sign in to comment.