diff --git a/run.py b/run.py
index b1f571e..68abc0f 100644
--- a/run.py
+++ b/run.py
@@ -98,9 +98,10 @@ def get_mia_scores(
         neighbors = data[f"neighbors"]
         print("Loaded neighbors from cache!")
 
-    collected_neighbors = {
-        n_perturbation: [] for n_perturbation in n_perturbation_list
-    }
+    if neigh_config and neigh_config.dump_cache:
+        collected_neighbors = {
+            n_perturbation: [] for n_perturbation in n_perturbation_list
+        }
 
     recall_config = config.recall_config
     if recall_config:
@@ -129,8 +130,9 @@ def get_mia_scores(
             detokenized_sample = [target_model.tokenizer.decode(s) for s in sample]
             sample_information["detokenized"] = detokenized_sample
 
+            if neigh_config and neigh_config.dump_cache:
+                neighbors_within = {n_perturbation: [] for n_perturbation in n_perturbation_list}
             # For each substring
-            neighbors_within = {n_perturbation: [] for n_perturbation in n_perturbation_list}
             for i, substr in enumerate(sample):
                 # compute token probabilities for sample
                 s_tk_probs, s_all_probs = (