From 86403761a200ffc366f1498240f0f131997174ad Mon Sep 17 00:00:00 2001 From: Anshuman Suri Date: Fri, 26 Jan 2024 08:49:49 -0500 Subject: [PATCH] Don't load neighbor model unless needed --- new_mi_experiment.py | 70 ++++++++++++++++++++++++-------------------- run.py | 5 ++-- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/new_mi_experiment.py b/new_mi_experiment.py index 2c520ad..8014cde 100644 --- a/new_mi_experiment.py +++ b/new_mi_experiment.py @@ -75,8 +75,9 @@ def run_blackbox_attacks( runnable_attacks.append(a) attacks = runnable_attacks - neighborhood_attacker = NeighborhoodAttack(config, target_model) - neighborhood_attacker.prepare() + if BlackBoxAttacks.NEIGHBOR in attacks: + neighborhood_attacker = NeighborhoodAttack(config, target_model) + neighborhood_attacker.prepare() results = defaultdict(list) for classification in keys_care_about: @@ -93,9 +94,8 @@ def run_blackbox_attacks( # For each batch of data # TODO: Batch-size isn't really "batching" data - change later + iterator = range(math.ceil(n_samples / batch_size)) if verbose: - iterator = range(math.ceil(n_samples / batch_size)) - else: iterator = tqdm(iterator, desc=f"Computing criterion") for batch in iterator: @@ -240,7 +240,10 @@ def run_blackbox_attacks( # Update collected scores for each sample with ref-based attack scores for classification, result in results.items(): - for r in tqdm(result, desc="Ref scores"): + itr = result + if verbose: + itr = tqdm(itr, desc="Ref scores") + for r in itr: ref_model_scores = [] for i, s in enumerate(r["sample"]): if config.pretokenized: @@ -580,32 +583,35 @@ def edit(x, n: int): x = base_model.tokenizer.decode(x_tok) return x - if config.load_from_cache: - with open( - f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json", - "r", - ) as f: - other_members_data = json.load(f) - n_try = list(other_members_data.keys()) - n_trials = len(other_members_data[n_try[0]]) - elif config.dump_cache: - # Try out multiple "distances" - n_try = [1, 5, 10, 25, 100] - # With multiple trials - n_trials = 50 - other_members_data = {} - for n in tqdm(n_try, "Generating edited members"): - trials = {} - for i in tqdm(range(n_trials)): - trials[i] = [edit(x, n) for x in data_member] - other_members_data[n] = trials - with open( - f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json", - "w", - ) as f: - json.dump(other_members_data, f) - print("Data dumped! Please re-run with load_from_cache set to True") - exit(0) + # """ + with open( + f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json", + "r", + ) as f: + other_members_data = json.load(f) + n_try = list(other_members_data.keys()) + n_trials = len(other_members_data[n_try[0]]) + # """ + + """ + # Try out multiple "distances" + n_try = [1, 5, 10, 25, 100] + # With multiple trials + n_trials = 20 + other_members_data = {} + for n in tqdm(n_try, "Generating edited members"): + trials = {} + for i in tqdm(range(n_trials)): + trials[i] = [edit(x, n) for x in data_member] + other_members_data[n] = trials + with open( + f"/p/distinf/uw_llm_collab/edit_distance_members/{config.specific_source}.json", + "w", + ) as f: + json.dump(other_members_data, f) + print("Data dumped! Please re-run with load_from_cache set to True") + exit(0) + """ ### # Using thresholds returned in blackbox_outputs, compute AUCs and ROC curves for other non-member sources @@ -631,7 +637,7 @@ def edit(x, n: int): ) pbar.update(1) - for attack in other_blackbox_predictions.keys(): + for attack in config.blackbox_attacks: score_dict[attack][n][i] = other_blackbox_predictions[attack]["member"] pbar.close() diff --git a/run.py b/run.py index fab0a46..849c1c6 100644 --- a/run.py +++ b/run.py @@ -73,8 +73,9 @@ def run_blackbox_attacks( runnable_attacks.append(a) attacks = runnable_attacks - neighborhood_attacker = NeighborhoodAttack(config, target_model) - neighborhood_attacker.prepare() + if BlackBoxAttacks.NEIGHBOR in attacks: + neighborhood_attacker = NeighborhoodAttack(config, target_model) + neighborhood_attacker.prepare() results = defaultdict(list) for classification in keys_care_about: