diff --git a/docs/attacks/all_attacks.html b/docs/attacks/all_attacks.html index 8157e78..69dbea4 100644 --- a/docs/attacks/all_attacks.html +++ b/docs/attacks/all_attacks.html @@ -54,6 +54,7 @@

Classes

MIN_K_PLUS_PLUS = "min_k++" # Done NEIGHBOR = "ne" # Done GRADNORM = "gradnorm" # Done + RECALL = "recall" # QUANTILE = "quantile" # Uncomment when tested implementation is available

Ancestors

@@ -83,6 +84,10 @@

Class variables

+
var RECALL
+
+
+
var REFERENCE_BASED
@@ -161,6 +166,7 @@

Subclasses

  • MinKPlusPlusAttack
  • NeighborhoodAttack
  • QuantileAttack
  • +
  • ReCaLLAttack
  • ReferenceAttack
  • ZLIBAttack
  • @@ -214,6 +220,7 @@

    MIN_K
  • MIN_K_PLUS_PLUS
  • NEIGHBOR
  • +
  • RECALL
  • REFERENCE_BASED
  • ZLIB
  • diff --git a/docs/attacks/index.html b/docs/attacks/index.html index a9ed6de..5a47c6a 100644 --- a/docs/attacks/index.html +++ b/docs/attacks/index.html @@ -63,6 +63,10 @@

    Sub-modules

    Implementation of the attack proposed in 'Scalable Membership Inference Attacks via Quantile Regression' https://arxiv.org/pdf/2307.03694.pdf

    +
    mimir.attacks.recall
    +
    +

    ReCaLL Attack: https://github.com/ruoyuxie/recall/

    +
    mimir.attacks.reference

    Reference-based attacks.

    @@ -109,6 +113,7 @@

    Sub-modules

  • mimir.attacks.min_k_plus_plus
  • mimir.attacks.neighborhood
  • mimir.attacks.quantile
  • +
  • mimir.attacks.recall
  • mimir.attacks.reference
  • mimir.attacks.utils
  • mimir.attacks.zlib
  • diff --git a/docs/attacks/recall.html b/docs/attacks/recall.html new file mode 100644 index 0000000..3516b23 --- /dev/null +++ b/docs/attacks/recall.html @@ -0,0 +1,235 @@ + + + + + + +mimir.attacks.recall API documentation + + + + + + + + + + + +
    +
    +
    +

    Module mimir.attacks.recall

    +
    +
    +

    ReCaLL Attack: https://github.com/ruoyuxie/recall/

    +
    +
    +
    +
    +
    +
    +
    +
    +

    Classes

    +
    +
    +class ReCaLLAttack +(config: ExperimentConfig, target_model: Model) +
    +
    +
    +
    + +Expand source code + +
    class ReCaLLAttack(Attack):
    +    # Scores a document as get_conditional_ll(...) / target_model.get_ll(...), i.e. prefix-conditioned score over unconditional score (see _attack).
    +    #** Note: this is a suboptimal implementation of the ReCaLL attack due to necessary changes made to integrate it alongside the other attacks
    +    #** for a better performing version, please refer to: https://github.com/ruoyuxie/recall 
    +    # Builds the attack with no reference model (ref_model = None); self.prefix caches the result of process_prefix.
    +    def __init__(self, config: ExperimentConfig, target_model: Model):
    +        super().__init__(config, target_model, ref_model = None)
    +        self.prefix = None
    +    # Returns the score for one document; reads kwargs["recall_dict"] and its "prefix", "num_shots" and "avg_length" entries.
    +    @torch.no_grad()
    +    def _attack(self, document, probs, tokens = None, **kwargs):        
    +        recall_dict: dict = kwargs.get("recall_dict", None)  # NOTE(review): None when the caller omits recall_dict, in which case the .get calls below raise AttributeError
    +
    +        nonmember_prefix = recall_dict.get("prefix")
    +        num_shots = recall_dict.get("num_shots")
    +        avg_length = recall_dict.get("avg_length")
    +        # The asserts reject any falsy value (None, 0, empty list); note that asserts are stripped when Python runs with -O.
    +        assert nonmember_prefix, "nonmember_prefix should not be None or empty"
    +        assert num_shots, "num_shots should not be None or empty"
    +        assert avg_length, "avg_length should not be None or empty"
    +        # Unconditional score from the target model first, then the score conditioned on the non-member prefix.
    +        lls = self.target_model.get_ll(document, probs = probs, tokens = tokens)
    +        ll_nonmember = self.get_conditional_ll(nonmember_prefix = nonmember_prefix, text = document,
    +                                                num_shots = num_shots, avg_length = avg_length,
    +                                                  tokens = tokens)
    +        recall = ll_nonmember / lls  # NOTE(review): no guard for lls == 0 — confirm get_ll can never return 0
    +
    +
    +        assert not np.isnan(recall)
    +        return recall
    +    # Returns the list of prefix shots to use, dropping shots when their token counts plus avg_length exceed model.max_length; computed once, then served from self.prefix.
    +    def process_prefix(self, prefix, avg_length, total_shots):
    +        model = self.target_model
    +        tokenizer = self.target_model.tokenizer
    +
    +        if self.prefix is not None:
    +            # We only need to process the prefix once, after that we can just return
    +            return self.prefix  # later calls ignore their arguments entirely
    +        # Token count of every shot; avg_length is the token budget reserved for the target (target_token_count).
    +        max_length = model.max_length
    +        token_counts = [len(tokenizer.encode(shot)) for shot in prefix]
    +
    +        target_token_count = avg_length
    +        total_tokens = sum(token_counts) + target_token_count
    +        if total_tokens<=max_length:
    +            self.prefix = prefix
    +            return self.prefix
    +        # Determine the maximum number of shots that can fit within the max_length
    +        max_shots = 0
    +        cumulative_tokens = target_token_count
    +        for count in token_counts:  # counts shots starting from the FRONT of the prefix list
    +            if cumulative_tokens + count <= max_length:
    +                max_shots += 1
    +                cumulative_tokens += count
    +            else:
    +                break
    +        # Truncate the prefix to include only the maximum number of shots
    +        truncated_prefix = prefix[-max_shots:]  # NOTE(review): keeps the LAST max_shots shots although the loop counted from the front; prefix[-0:] is the whole list when max_shots == 0 — confirm intent
    +        print(f"""\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""")  # total_shots is only used in this message
    +        self.prefix = truncated_prefix
    +        return self.prefix
    +    # Returns the negated, token-weighted average loss of text (or tokens) preceded by the processed non-member prefix; raises ValueError if the prefix alone reaches model.max_length.
    +    def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None):
    +        assert nonmember_prefix, "nonmember_prefix should not be None or empty"
    +
    +        model = self.target_model
    +        tokenizer = self.target_model.tokenizer
    +        # When tokens is supplied it is used as-is and must expose .input_ids like the tokenizer output does.
    +        if tokens is None:
    +            target_encodings = tokenizer(text=text, return_tensors="pt")
    +        else:
    +            target_encodings = tokens
    +
    +        processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots)
    +        input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt")  # shots are concatenated with no separator
    +
    +        prefix_ids = input_encodings.input_ids.to(model.device)
    +        text_ids = target_encodings.input_ids.to(model.device)
    +
    +        max_length = model.max_length
    +
    +        if prefix_ids.size(1) >= max_length:
    +            raise ValueError("Prefix length exceeds or equals the model's maximum context window.")
    +        # labels holds the prefix ids followed by the text ids; it is scored below in non-overlapping windows of max_length tokens.
    +        labels = torch.cat((prefix_ids, text_ids), dim=1)
    +        total_length = labels.size(1)
    +
    +        total_loss = 0
    +        total_tokens = 0
    +        with torch.no_grad():  # also covers direct calls that do not go through the no_grad-decorated _attack
    +            for i in range(0, total_length, max_length):
    +                begin_loc = i
    +                end_loc = min(i + max_length, total_length)
    +                trg_len = end_loc - begin_loc
    +                
    +                input_ids = labels[:, begin_loc:end_loc].to(model.device)
    +                target_ids = input_ids.clone()
    +                # Prefix positions inside this window are set to -100, which this code treats as "masked" (presumably the model's ignore index — TODO confirm).
    +                if begin_loc < prefix_ids.size(1):
    +                    prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length)
    +                    target_ids[:, :prefix_overlap] = -100
    +                # NOTE(review): target_ids is a clone of input_ids, so re-copying the trailing text ids below looks redundant — confirm.
    +                if end_loc > total_length - text_ids.size(1):
    +                    target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length)
    +                    target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:]
    +                # Nothing to score when every position in the window is masked.
    +                if torch.all(target_ids == -100):
    +                    continue
    +                
    +                outputs = model.model(input_ids, labels=target_ids)
    +                loss = outputs.loss  # presumably the mean loss over non-masked labels — verify against the model wrapper
    +                if torch.isnan(loss):
    +                    print(f"NaN detected in loss at iteration {i}. Non masked target_ids size is {(target_ids != -100).sum().item()}")
    +                    continue  # NaN windows are skipped and contribute nothing to the average
    +                non_masked_tokens = (target_ids != -100).sum().item()
    +                total_loss += loss.item() * non_masked_tokens
    +                total_tokens += non_masked_tokens
    +        # Average weighted by non-masked token count per window; evaluates to 0 when no window contributed.
    +        average_loss = total_loss / total_tokens if total_tokens > 0 else 0
    +        return -average_loss
    +
    +

    Ancestors

    + +

    Methods

    +
    +
    +def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None) +
    +
    +
    +
    +
    +def process_prefix(self, prefix, avg_length, total_shots) +
    +
    +
    +
    +
    +

    Inherited members

    + +
    +
    +
    +
    + +
    + + + diff --git a/docs/config.html b/docs/config.html index e2cb034..fe4b670 100644 --- a/docs/config.html +++ b/docs/config.html @@ -132,7 +132,7 @@

    Class variables

    class ExperimentConfig -(experiment_name: str, base_model: str, dataset_member: str, dataset_nonmember: str, output_name: str = None, dataset_nonmember_other_sources: Optional[List[str]] = <factory>, pretokenized: Optional[bool] = False, revision: Optional[str] = None, presampled_dataset_member: Optional[str] = None, presampled_dataset_nonmember: Optional[str] = None, token_frequency_map: Optional[str] = None, dataset_key: Optional[str] = None, specific_source: Optional[str] = None, full_doc: Optional[bool] = False, max_substrs: Optional[int] = 20, dump_cache: Optional[bool] = False, load_from_cache: Optional[bool] = False, load_from_hf: Optional[bool] = True, blackbox_attacks: Optional[List[str]] = <factory>, tokenization_attack: Optional[bool] = False, quantile_attack: Optional[bool] = False, n_samples: Optional[int] = 200, max_tokens: Optional[int] = 512, max_data: Optional[int] = 5000, min_words: Optional[int] = 100, max_words: Optional[int] = 200, max_words_cutoff: Optional[bool] = True, batch_size: Optional[int] = 50, chunk_size: Optional[int] = 20, scoring_model_name: Optional[str] = None, top_k: Optional[int] = 40, do_top_k: Optional[bool] = False, top_p: Optional[float] = 0.96, do_top_p: Optional[bool] = False, pre_perturb_pct: Optional[float] = 0.0, pre_perturb_span_length: Optional[int] = 5, tok_by_tok: Optional[bool] = False, fpr_list: Optional[List[float]] = <factory>, random_seed: Optional[int] = 0, ref_config: Optional[ReferenceConfig] = None, neighborhood_config: Optional[NeighborhoodConfig] = None, env_config: Optional[EnvironmentConfig] = None, openai_config: Optional[OpenAIConfig] = None) +(experiment_name: str, base_model: str, dataset_member: str, dataset_nonmember: str, output_name: str = None, dataset_nonmember_other_sources: Optional[List[str]] = <factory>, pretokenized: Optional[bool] = False, revision: Optional[str] = None, presampled_dataset_member: Optional[str] = None, presampled_dataset_nonmember: Optional[str] = None, 
token_frequency_map: Optional[str] = None, dataset_key: Optional[str] = None, specific_source: Optional[str] = None, full_doc: Optional[bool] = False, max_substrs: Optional[int] = 20, dump_cache: Optional[bool] = False, load_from_cache: Optional[bool] = False, load_from_hf: Optional[bool] = True, blackbox_attacks: Optional[List[str]] = <factory>, tokenization_attack: Optional[bool] = False, quantile_attack: Optional[bool] = False, n_samples: Optional[int] = 200, max_tokens: Optional[int] = 512, max_data: Optional[int] = 5000, min_words: Optional[int] = 100, max_words: Optional[int] = 200, max_words_cutoff: Optional[bool] = True, batch_size: Optional[int] = 50, chunk_size: Optional[int] = 20, scoring_model_name: Optional[str] = None, top_k: Optional[int] = 40, do_top_k: Optional[bool] = False, top_p: Optional[float] = 0.96, do_top_p: Optional[bool] = False, pre_perturb_pct: Optional[float] = 0.0, pre_perturb_span_length: Optional[int] = 5, tok_by_tok: Optional[bool] = False, fpr_list: Optional[List[float]] = <factory>, random_seed: Optional[int] = 0, ref_config: Optional[ReferenceConfig] = None, recall_config: Optional[ReCaLLConfig] = None, neighborhood_config: Optional[NeighborhoodConfig] = None, env_config: Optional[EnvironmentConfig] = None, openai_config: Optional[OpenAIConfig] = None)

    Config for attacks

    @@ -231,6 +231,8 @@

    Class variables

    """Random seed""" ref_config: Optional[ReferenceConfig] = None """Reference model config""" + recall_config: Optional[ReCaLLConfig] = None + """ReCaLL attack config""" neighborhood_config: Optional[NeighborhoodConfig] = None """Neighborhood attack config""" env_config: Optional[EnvironmentConfig] = None @@ -401,6 +403,10 @@

    Class variables

    Random seed

    +
    var recall_config : Optional[ReCaLLConfig]
    +
    +

    ReCaLL attack config

    +
    var ref_config : Optional[ReferenceConfig]

    Reference model config

    @@ -605,6 +611,41 @@

    Class variables

    +
    +class ReCaLLConfig +(num_shots: Optional[int] = 1) +
    +
    +

    Config for ReCaLL attack

    +
    + +Expand source code + +
    @dataclass
    +class ReCaLLConfig(Serializable):
    +    """
    +    Config for ReCaLL attack
    +    """
    +    num_shots: Optional[int] = 1  # NOTE(review): ReCaLLAttack._attack asserts the num_shots it receives is truthy, so None or 0 would fail there — presumably that value comes from this field; verify against caller
    +    """Number of shots for ReCaLL Attacks"""
    +
    +

    Ancestors

    + +

    Class variables

    +
    +
    var decode_into_subclasses : ClassVar[bool]
    +
    +
    +
    +
    var num_shots : Optional[int]
    +
    +

    Number of shots for ReCaLL Attacks

    +
    +
    +
    class ReferenceConfig (models: List[str]) @@ -714,6 +755,7 @@

    pretokenized
  • quantile_attack
  • random_seed
  • +
  • recall_config
  • ref_config
  • revision
  • scoring_model_name
  • @@ -755,6 +797,13 @@

  • +

    ReCaLLConfig

    + +
  • +
  • ReferenceConfig