From 2772f727d42cd08e93f81048d5b01877b34d1bb1 Mon Sep 17 00:00:00 2001
From: GitHub Actions
Date: Mon, 16 Sep 2024 13:04:55 +0000
Subject: [PATCH] Update documentation

---
 docs/attacks/all_attacks.html     |  151 +---
 docs/attacks/attack_utils.html    |  350 +-------
 docs/attacks/gradnorm.html        |   80 +-
 docs/attacks/index.html           |   36 +-
 docs/attacks/loss.html            |   53 +-
 docs/attacks/min_k.html           |   71 +-
 docs/attacks/min_k_plus_plus.html |   71 +-
 docs/attacks/neighborhood.html    | 1134 +------------------------
 docs/attacks/quantile.html        |  193 +----
 docs/attacks/reference.html       |   59 +-
 docs/attacks/utils.html           |   78 +-
 docs/attacks/zlib.html            |   68 +-
 docs/config.html                  |  251 +-----
 docs/custom_datasets.html         |  418 +--
 docs/data_utils.html              |  688 +--------------
 docs/index.html                   |   28 +-
 docs/models.html                  | 1314 +----------------------------
 docs/plot_utils.html              |  288 +------
 docs/utils.html                   |  139 +--
 19 files changed, 291 insertions(+), 5179 deletions(-)

diff --git a/docs/attacks/all_attacks.html b/docs/attacks/all_attacks.html
index a6a4391..8157e78 100644
--- a/docs/attacks/all_attacks.html
+++ b/docs/attacks/all_attacks.html
@@ -2,18 +2,21 @@
 mimir.attacks.all_attacks API documentation
@@ -23,81 +26,6 @@

Module mimir.attacks.all_attacks

Enum class for attacks. Also contains the base attack class.

-
- -Expand source code - -
"""
-    Enum class for attacks. Also contains the base attack class.
-"""
-
-from enum import Enum
-from mimir.models import Model
-
-
-# Attack definitions
-class AllAttacks(str, Enum):
-    LOSS = "loss" # Done
-    REFERENCE_BASED = "ref" # Done
-    ZLIB = "zlib" # Done
-    MIN_K = "min_k" # Done
-    MIN_K_PLUS_PLUS = "min_k++" # Done
-    NEIGHBOR = "ne" # Done
-    GRADNORM = "gradnorm" # Done
-    # QUANTILE = "quantile" # Uncomment when tested implementation is available
-
-
-# Base attack class
-class Attack:
-    def __init__(self, config, target_model: Model, ref_model: Model = None, is_blackbox: bool = True):
-        self.config = config
-        self.target_model = target_model
-        self.ref_model = ref_model
-        self.is_loaded = False
-        self.is_blackbox = is_blackbox
-
-    def load(self):
-        """
-        Any attack-specific steps (one-time) preparation
-        """
-        if self.ref_model is not None:
-            self.ref_model.load()
-            self.is_loaded = True
-
-    def unload(self):
-        if self.ref_model is not None:
-            self.ref_model.unload()
-            self.is_loaded = False
-
-    def _attack(self, document, probs, tokens=None, **kwargs):
-        """
-        Actual logic for attack. 
-        """
-        raise NotImplementedError("Attack must implement attack()")
-
-    def attack(self, document, probs, **kwargs):
-        """
-        Score a document using the attack's scoring function. Calls self._attack
-        """
-        # Load attack if not loaded yet
-        if not self.is_loaded:
-            self.load()
-            self.is_loaded = True
-
-        detokenized_sample = kwargs.get("detokenized_sample", None)
-        if self.config.pretokenized and detokenized_sample is None:
-            raise ValueError("detokenized_sample must be provided")
-
-        score = (
-            self._attack(document, probs=probs, **kwargs)
-            if not self.config.pretokenized
-            else self._attack(
-                detokenized_sample, tokens=document, probs=probs, **kwargs
-            )
-        )
-
-        return score
-
@@ -243,66 +171,18 @@

Methods

Score a document using the attack's scoring function. Calls self._attack

-
- -Expand source code - -
def attack(self, document, probs, **kwargs):
-    """
-    Score a document using the attack's scoring function. Calls self._attack
-    """
-    # Load attack if not loaded yet
-    if not self.is_loaded:
-        self.load()
-        self.is_loaded = True
-
-    detokenized_sample = kwargs.get("detokenized_sample", None)
-    if self.config.pretokenized and detokenized_sample is None:
-        raise ValueError("detokenized_sample must be provided")
-
-    score = (
-        self._attack(document, probs=probs, **kwargs)
-        if not self.config.pretokenized
-        else self._attack(
-            detokenized_sample, tokens=document, probs=probs, **kwargs
-        )
-    )
-
-    return score
-
def load(self)

Any attack-specific steps (one-time) preparation

-
- -Expand source code - -
def load(self):
-    """
-    Any attack-specific steps (one-time) preparation
-    """
-    if self.ref_model is not None:
-        self.ref_model.load()
-        self.is_loaded = True
-
def unload(self)
-
- -Expand source code - -
def unload(self):
-    if self.ref_model is not None:
-        self.ref_model.unload()
-        self.is_loaded = False
-
@@ -315,7 +195,6 @@

Methods

MIMIR -

Index

    @@ -353,7 +232,7 @@

    -

    Generated by pdoc 0.10.0.

    +

    Generated by pdoc 0.11.1.
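The removed source above defines the contract every attack follows: subclass Attack, implement _attack(), and let attack() handle one-time loading and the pretokenized/detokenized dispatch. Below is a minimal sketch of a custom attack written against that contract; the class name and scoring rule are illustrative, and only the Attack, Model, and ExperimentConfig interfaces are taken from the modules documented here.

import numpy as np

from mimir.attacks.all_attacks import Attack
from mimir.config import ExperimentConfig
from mimir.models import Model


class AverageLogProbAttack(Attack):
    # Illustrative attack: score a document by its negated mean token log-probability.
    def __init__(self, config: ExperimentConfig, target_model: Model):
        # No reference model needed, so this stays a black-box attack.
        super().__init__(config, target_model, ref_model=None)

    def _attack(self, document, probs, tokens=None, **kwargs):
        # `probs` holds per-token log-probabilities when the harness pre-computes them;
        # otherwise, query the target model directly.
        all_prob = (
            probs
            if probs is not None
            else self.target_model.get_probabilities(document, tokens=tokens)
        )
        return -np.mean(all_prob)


# Hypothetical usage, assuming `config` and `target_model` are already constructed:
#   attack = AverageLogProbAttack(config, target_model)
#   score = attack.attack(document, probs=None)  # lower scores are expected for members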

\ No newline at end of file

diff --git a/docs/attacks/attack_utils.html b/docs/attacks/attack_utils.html
index cadafca..31682a8 100644
--- a/docs/attacks/attack_utils.html
+++ b/docs/attacks/attack_utils.html
@@ -2,18 +2,21 @@
 mimir.attacks.attack_utils API documentation
    @@ -23,164 +26,6 @@

    Module mimir.attacks.attack_utils

    Utility functions for attacks

    -
    - -Expand source code - -
    """
    -    Utility functions for attacks
    -"""
    -from typing import List
    -import torch
    -from collections import Counter
    -import math
    -import numpy as np
    -from sklearn.metrics import roc_curve, auc, precision_recall_curve
    -from scipy.stats import bootstrap
    -
    -
    -def count_masks(texts):
    -    return [
    -        len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts
    -    ]
    -
    -
    -def apply_extracted_fills(masked_texts: List[str], extracted_fills):
    -    # split masked text into tokens, only splitting on spaces (not newlines)
    -    tokens = [x.split(" ") for x in masked_texts]
    -
    -    n_expected = count_masks(masked_texts)
    -
    -    # replace each mask token with the corresponding fill
    -    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
    -        if len(fills) < n:
    -            tokens[idx] = []
    -        else:
    -            for fill_idx in range(n):
    -                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
    -
    -    # join tokens back into text
    -    texts = [" ".join(x) for x in tokens]
    -    return texts
    -
    -
    -def f1_score(prediction, ground_truth):
    -    """
    -        Compute F1 score for given prediction and ground truth.
    -    """
    -    common = Counter(prediction) & Counter(ground_truth)
    -    num_same = sum(common.values())
    -    if num_same == 0:
    -        return 0, 0, 0
    -    precision = 1.0 * num_same / len(prediction)
    -    recall = 1.0 * num_same / len(ground_truth)
    -    f1 = (2 * precision * recall) / (precision + recall)
    -    print(num_same, f1, precision, recall)
    -    return f1, precision, recall
    -
    -
    -def get_auc_from_thresholds(preds_member, preds_nonmember, thresholds):
    -    """
    -    Compute FPRs and TPRs corresponding to given thresholds
    -    """
    -    tpr, fpr = [], []
    -    for threshold in thresholds:
    -        tp = np.sum(preds_nonmember >= threshold)
    -        fn = np.sum(preds_nonmember < threshold)
    -        fp = np.sum(preds_member >= threshold)
    -        tn = np.sum(preds_member < threshold)
    -
    -        tpr.append(tp / (tp + fn))
    -        fpr.append(fp / (fp + tn))
    -    
    -    tpr = np.array(tpr)
    -    fpr = np.array(fpr)
    -    roc_auc = auc(fpr, tpr)
    -    return roc_auc
    -
    -
    -def get_roc_metrics(
    -    preds_member,
    -    preds_nonmember,
    -    perform_bootstrap: bool = False,
    -    return_thresholds: bool = False,
    -):  # fpr_list,
    -    preds_member_ = filter_out_nan(preds_member)
    -    preds_nonmember_ = filter_out_nan(preds_nonmember)
    -    total_preds = preds_member_ + preds_nonmember_
    -    # While roc_auc is unaffected by which class we consider
    -    # positive/negative, the TPR@lowFPR calculation is.
    -    # Make sure the members are positive class (larger values, so negate the raw MIA scores)
    -    total_preds = np.array(total_preds) * -1
    -    # Assign label '0' to members for computation, since sklearn
-    # expects label '0' data to have lower values to get assigned that label
    -    # which is true for our attacks (lower loss for members, e.g.)
    -    total_labels = [1] * len(preds_member_) + [0] * len(preds_nonmember_)
    -    fpr, tpr, thresholds = roc_curve(total_labels, total_preds)
    -
    -    roc_auc = auc(fpr, tpr)
    -    # tpr_at_low_fpr = {upper_bound: tpr[np.where(np.array(fpr) < upper_bound)[0][-1]] for upper_bound in fpr_list}
    -
    -    if perform_bootstrap:
    -
    -        def roc_auc_statistic(preds, labels):
    -            in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
    -            out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
    -            _, _, roc_auc = get_roc_metrics(in_preds, out_preds)
    -            return roc_auc
    -
    -        auc_roc_res = bootstrap(
    -            (total_preds, total_labels),
    -            roc_auc_statistic,
    -            n_resamples=1000,
    -            paired=True,
    -        )
    -
    -        # tpr_at_low_fpr_res = {}
    -        # for ub in fpr_list:
    -        #     def tpr_at_fpr_statistic(preds, labels):
    -        #         in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
    -        #         out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
    -        #         _, _, _, tpr_at_low_fpr = get_roc_metrics(in_preds, out_preds, [ub])
    -        #         return tpr_at_low_fpr[ub]
    -
    -        #     tpr_at_low_fpr_res[ub] = bootstrap((total_preds, total_labels), tpr_at_fpr_statistic, n_resamples=1000, paired=True)
    -
    -        if return_thresholds:
    -            return (
    -                fpr.tolist(),
    -                tpr.tolist(),
    -                float(roc_auc),
    -                auc_roc_res,
    -                thresholds.tolist(),
    -            )
    -        return (
    -            fpr.tolist(),
    -            tpr.tolist(),
    -            float(roc_auc),
    -            auc_roc_res,
    -        )  # tpr_at_low_fpr, tpr_at_low_fpr_res
    -
    -    if return_thresholds:
    -        return fpr.tolist(), tpr.tolist(), float(roc_auc), thresholds.tolist()
    -    return fpr.tolist(), tpr.tolist(), float(roc_auc)  # , tpr_at_low_fpr
    -
    -
    -def get_precision_recall_metrics(preds_member, preds_nonmember):
    -    preds_member_ = filter_out_nan(preds_member)
    -    preds_nonmember_ = filter_out_nan(preds_nonmember)
    -    total_preds = preds_member_ + preds_nonmember_
    -
    -    total_labels = [0] * len(preds_member_) + [1] * len(preds_nonmember_)
    -
    -    precision, recall, _ = precision_recall_curve(total_labels, total_preds)
    -    pr_auc = auc(recall, precision)
    -    return precision.tolist(), recall.tolist(), float(pr_auc)
    -
    -
    -def filter_out_nan(x):
    -    return [element for element in x if not math.isnan(element)]
    -
    @@ -194,206 +39,42 @@

    Functions

    -
    - -Expand source code - -
    def apply_extracted_fills(masked_texts: List[str], extracted_fills):
    -    # split masked text into tokens, only splitting on spaces (not newlines)
    -    tokens = [x.split(" ") for x in masked_texts]
    -
    -    n_expected = count_masks(masked_texts)
    -
    -    # replace each mask token with the corresponding fill
    -    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
    -        if len(fills) < n:
    -            tokens[idx] = []
    -        else:
    -            for fill_idx in range(n):
    -                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
    -
    -    # join tokens back into text
    -    texts = [" ".join(x) for x in tokens]
    -    return texts
    -
    def count_masks(texts)
    -
    - -Expand source code - -
    def count_masks(texts):
    -    return [
    -        len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts
    -    ]
    -
    def f1_score(prediction, ground_truth)

    Compute F1 score for given prediction and ground truth.

    -
    - -Expand source code - -
    def f1_score(prediction, ground_truth):
    -    """
    -        Compute F1 score for given prediction and ground truth.
    -    """
    -    common = Counter(prediction) & Counter(ground_truth)
    -    num_same = sum(common.values())
    -    if num_same == 0:
    -        return 0, 0, 0
    -    precision = 1.0 * num_same / len(prediction)
    -    recall = 1.0 * num_same / len(ground_truth)
    -    f1 = (2 * precision * recall) / (precision + recall)
    -    print(num_same, f1, precision, recall)
    -    return f1, precision, recall
    -
    def filter_out_nan(x)
    -
    - -Expand source code - -
    def filter_out_nan(x):
    -    return [element for element in x if not math.isnan(element)]
    -
    def get_auc_from_thresholds(preds_member, preds_nonmember, thresholds)

    Compute FPRs and TPRs corresponding to given thresholds

    -
    - -Expand source code - -
    def get_auc_from_thresholds(preds_member, preds_nonmember, thresholds):
    -    """
    -    Compute FPRs and TPRs corresponding to given thresholds
    -    """
    -    tpr, fpr = [], []
    -    for threshold in thresholds:
    -        tp = np.sum(preds_nonmember >= threshold)
    -        fn = np.sum(preds_nonmember < threshold)
    -        fp = np.sum(preds_member >= threshold)
    -        tn = np.sum(preds_member < threshold)
    -
    -        tpr.append(tp / (tp + fn))
    -        fpr.append(fp / (fp + tn))
    -    
    -    tpr = np.array(tpr)
    -    fpr = np.array(fpr)
    -    roc_auc = auc(fpr, tpr)
    -    return roc_auc
    -
    def get_precision_recall_metrics(preds_member, preds_nonmember)
    -
    - -Expand source code - -
    def get_precision_recall_metrics(preds_member, preds_nonmember):
    -    preds_member_ = filter_out_nan(preds_member)
    -    preds_nonmember_ = filter_out_nan(preds_nonmember)
    -    total_preds = preds_member_ + preds_nonmember_
    -
    -    total_labels = [0] * len(preds_member_) + [1] * len(preds_nonmember_)
    -
    -    precision, recall, _ = precision_recall_curve(total_labels, total_preds)
    -    pr_auc = auc(recall, precision)
    -    return precision.tolist(), recall.tolist(), float(pr_auc)
    -
    def get_roc_metrics(preds_member, preds_nonmember, perform_bootstrap: bool = False, return_thresholds: bool = False)
    -
    - -Expand source code - -
    def get_roc_metrics(
    -    preds_member,
    -    preds_nonmember,
    -    perform_bootstrap: bool = False,
    -    return_thresholds: bool = False,
    -):  # fpr_list,
    -    preds_member_ = filter_out_nan(preds_member)
    -    preds_nonmember_ = filter_out_nan(preds_nonmember)
    -    total_preds = preds_member_ + preds_nonmember_
    -    # While roc_auc is unaffected by which class we consider
    -    # positive/negative, the TPR@lowFPR calculation is.
    -    # Make sure the members are positive class (larger values, so negate the raw MIA scores)
    -    total_preds = np.array(total_preds) * -1
    -    # Assign label '0' to members for computation, since sklearn
-    # expects label '0' data to have lower values to get assigned that label
    -    # which is true for our attacks (lower loss for members, e.g.)
    -    total_labels = [1] * len(preds_member_) + [0] * len(preds_nonmember_)
    -    fpr, tpr, thresholds = roc_curve(total_labels, total_preds)
    -
    -    roc_auc = auc(fpr, tpr)
    -    # tpr_at_low_fpr = {upper_bound: tpr[np.where(np.array(fpr) < upper_bound)[0][-1]] for upper_bound in fpr_list}
    -
    -    if perform_bootstrap:
    -
    -        def roc_auc_statistic(preds, labels):
    -            in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
    -            out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
    -            _, _, roc_auc = get_roc_metrics(in_preds, out_preds)
    -            return roc_auc
    -
    -        auc_roc_res = bootstrap(
    -            (total_preds, total_labels),
    -            roc_auc_statistic,
    -            n_resamples=1000,
    -            paired=True,
    -        )
    -
    -        # tpr_at_low_fpr_res = {}
    -        # for ub in fpr_list:
    -        #     def tpr_at_fpr_statistic(preds, labels):
    -        #         in_preds = [pred for pred, label in zip(preds, labels) if label == 1]
    -        #         out_preds = [pred for pred, label in zip(preds, labels) if label == 0]
    -        #         _, _, _, tpr_at_low_fpr = get_roc_metrics(in_preds, out_preds, [ub])
    -        #         return tpr_at_low_fpr[ub]
    -
    -        #     tpr_at_low_fpr_res[ub] = bootstrap((total_preds, total_labels), tpr_at_fpr_statistic, n_resamples=1000, paired=True)
    -
    -        if return_thresholds:
    -            return (
    -                fpr.tolist(),
    -                tpr.tolist(),
    -                float(roc_auc),
    -                auc_roc_res,
    -                thresholds.tolist(),
    -            )
    -        return (
    -            fpr.tolist(),
    -            tpr.tolist(),
    -            float(roc_auc),
    -            auc_roc_res,
    -        )  # tpr_at_low_fpr, tpr_at_low_fpr_res
    -
    -    if return_thresholds:
    -        return fpr.tolist(), tpr.tolist(), float(roc_auc), thresholds.tolist()
    -    return fpr.tolist(), tpr.tolist(), float(roc_auc)  # , tpr_at_low_fpr
    -
    @@ -406,7 +87,6 @@

    Functions

    MIMIR -

    Index

      @@ -431,7 +111,7 @@

      Index
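A small usage sketch for the scoring utilities documented above; the score arrays are fabricated, and only get_roc_metrics and get_precision_recall_metrics come from this module. Note that get_roc_metrics negates the raw scores internally, so lower raw scores (e.g. lower loss) should correspond to members.

import numpy as np

from mimir.attacks.attack_utils import get_precision_recall_metrics, get_roc_metrics

# Fabricated raw MIA scores (e.g. per-document loss); members tend to score lower.
rng = np.random.default_rng(0)
member_scores = rng.normal(loc=2.0, scale=0.5, size=200).tolist()
nonmember_scores = rng.normal(loc=2.5, scale=0.5, size=200).tolist()

fpr, tpr, roc_auc = get_roc_metrics(member_scores, nonmember_scores)
precision, recall, pr_auc = get_precision_recall_metrics(member_scores, nonmember_scores)
print(f"ROC AUC: {roc_auc:.3f}  PR AUC: {pr_auc:.3f}")

# With perform_bootstrap=True, a scipy bootstrap result for the AUC is returned as well:
#   fpr, tpr, roc_auc, auc_res = get_roc_metrics(member_scores, nonmember_scores,
#                                                perform_bootstrap=True)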

\ No newline at end of file

diff --git a/docs/attacks/gradnorm.html b/docs/attacks/gradnorm.html
index 73ee0ae..2a4ef60 100644
--- a/docs/attacks/gradnorm.html
+++ b/docs/attacks/gradnorm.html
@@ -2,18 +2,21 @@
 mimir.attacks.gradnorm API documentation
      @@ -23,58 +26,6 @@

      Module mimir.attacks.gradnorm

      Gradient-norm attack. Proposed for MIA in multiple settings, and particularly experimented for pre-training data and LLMs in https://arxiv.org/abs/2402.17012

      -
      - -Expand source code - -
      """
      -    Gradient-norm attack. Proposed for MIA in multiple settings, and particularly experimented for pre-training data and LLMs in https://arxiv.org/abs/2402.17012
      -"""
      -
      -import torch as ch
      -import numpy as np
      -from mimir.attacks.all_attacks import Attack
      -from mimir.models import Model
      -from mimir.config import ExperimentConfig
      -
      -
      -class GradNormAttack(Attack):
      -    def __init__(self, config: ExperimentConfig, model: Model):
      -        super().__init__(config, model, ref_model=None, is_blackbox=False)
      -
      -    def _attack(self, document, probs, tokens=None, **kwargs):
      -        """
      -        Gradient Norm Attack. Computes p-norm of gradients w.r.t. model parameters.
      -        """
      -        # We ignore probs here since they are computed in the general case without gradient-tracking (to save memory)
      -
-        # Hyper-params specific to the gradient-norm attack
      -        p: float = kwargs.get("p", np.inf)
      -        if p not in [1, 2, np.inf]:
      -            raise ValueError(f"Invalid p-norm value: {p}.")
      -
      -        # Make sure model params require gradients
      -        # for name, param in self.target_model.model.named_parameters():
      -        #    param.requires_grad = True
      -
      -        # Get gradients for model parameters
      -        self.target_model.model.zero_grad()
      -        all_prob = self.target_model.get_probabilities(document, tokens=tokens, no_grads=False)
      -        loss = - ch.mean(all_prob)
      -        loss.backward()
      -
      -        # Compute p-norm of gradients (for all model params where grad exists)
      -        grad_norms = []
      -        for param in self.target_model.model.parameters():
      -            if param.grad is not None:
      -                grad_norms.append(param.grad.detach().norm(p))
      -        grad_norm = ch.stack(grad_norms).mean()
      -
      -        # Zero out gradients again
      -        self.target_model.model.zero_grad()
      -
      -        return -grad_norm.cpu().numpy()
      -
      @@ -155,7 +106,6 @@

      Inherited members

      MIMIR -

      Index

        @@ -176,7 +126,7 @@

        -

        Generated by pdoc 0.10.0.

        +

        Generated by pdoc 0.11.1.

\ No newline at end of file

diff --git a/docs/attacks/index.html b/docs/attacks/index.html
index 7d97cc6..a9ed6de 100644
--- a/docs/attacks/index.html
+++ b/docs/attacks/index.html
@@ -2,18 +2,21 @@
 mimir.attacks API documentation
        @@ -23,14 +26,6 @@

        Module mimir.attacks

        Attack implementations.

        -
        - -Expand source code - -
        """
        -    Attack implementations.
        -"""
        -

        Sub-modules

        @@ -95,7 +90,6 @@

        Sub-modules

        MIMIR -

        Index

          @@ -124,7 +118,7 @@

          Index

\ No newline at end of file

diff --git a/docs/attacks/loss.html b/docs/attacks/loss.html
index 9e2b6e0..f37d75e 100644
--- a/docs/attacks/loss.html
+++ b/docs/attacks/loss.html
@@ -2,18 +2,21 @@
 mimir.attacks.loss API documentation
          @@ -23,31 +26,6 @@

          Module mimir.attacks.loss

          Straight-forward LOSS attack, as described in https://ieeexplore.ieee.org/abstract/document/8429311

          -
          - -Expand source code - -
          """
          -    Straight-forward LOSS attack, as described in https://ieeexplore.ieee.org/abstract/document/8429311
          -"""
          -import torch as ch
          -from mimir.attacks.all_attacks import Attack
          -from mimir.models import Model
          -from mimir.config import ExperimentConfig
          -
          -
          -class LOSSAttack(Attack):
          -
          -    def __init__(self, config: ExperimentConfig, model: Model):
          -        super().__init__(config, model, ref_model=None)
          -
          -    @ch.no_grad()
          -    def _attack(self, document, probs, tokens=None, **kwargs):
          -        """
          -            LOSS-score. Use log-likelihood from model.
          -        """
          -        return self.target_model.get_ll(document, probs=probs, tokens=tokens)
          -
          @@ -103,7 +81,6 @@

          Inherited members

          MIMIR -

          Index

            @@ -124,7 +101,7 @@

            -

            Generated by pdoc 0.10.0.

            +

            Generated by pdoc 0.11.1.

\ No newline at end of file

diff --git a/docs/attacks/min_k.html b/docs/attacks/min_k.html
index 375ac50..841311a 100644
--- a/docs/attacks/min_k.html
+++ b/docs/attacks/min_k.html
@@ -2,18 +2,21 @@
 mimir.attacks.min_k API documentation
            @@ -23,49 +26,6 @@

            Module mimir.attacks.min_k

            Min-k % Prob Attack: https://arxiv.org/pdf/2310.16789.pdf

            -
            - -Expand source code - -
            """
            -    Min-k % Prob Attack: https://arxiv.org/pdf/2310.16789.pdf
            -"""
            -import torch as ch
            -import numpy as np
            -from mimir.attacks.all_attacks import Attack
            -from mimir.models import Model
            -from mimir.config import ExperimentConfig
            -
            -
            -class MinKProbAttack(Attack):
            -
            -    def __init__(self, config: ExperimentConfig, model: Model):
            -        super().__init__(config, model, ref_model=None)
            -
            -    @ch.no_grad()
            -    def _attack(self, document, probs, tokens=None, **kwargs):
            -        """
            -        Min-k % Prob Attack. Gets model probabilities and returns likelihood when computed over top k% of ngrams.
            -        """
            -        # Hyper-params specific to min-k attack
            -        k: float = kwargs.get("k", 0.2)
            -        window: int = kwargs.get("window", 1)
            -        stride: int = kwargs.get("stride", 1)
            -
            -        all_prob = (
            -            probs
            -            if probs is not None
            -            else self.target_model.get_probabilities(document, tokens=tokens)
            -        )
            -        # iterate through probabilities by ngram defined by window size at given stride
            -        ngram_probs = []
            -        for i in range(0, len(all_prob) - window + 1, stride):
            -            ngram_prob = all_prob[i : i + window]
            -            ngram_probs.append(np.mean(ngram_prob))
            -        min_k_probs = sorted(ngram_probs)[: int(len(ngram_probs) * k)]
            -
            -        return -np.mean(min_k_probs)
            -
            @@ -138,7 +98,6 @@

            Inherited members

            MIMIR -

            Index

              @@ -159,7 +118,7 @@

              -

              Generated by pdoc 0.10.0.

              +

              Generated by pdoc 0.11.1.

\ No newline at end of file

diff --git a/docs/attacks/min_k_plus_plus.html b/docs/attacks/min_k_plus_plus.html
index cb4bd9f..bced522 100644
--- a/docs/attacks/min_k_plus_plus.html
+++ b/docs/attacks/min_k_plus_plus.html
@@ -2,18 +2,21 @@
 mimir.attacks.min_k_plus_plus API documentation
              @@ -23,49 +26,6 @@

              Module mimir.attacks.min_k_plus_plus

              Min-K%++ Attack: https://github.com/zjysteven/mink-plus-plus

              -
              - -Expand source code - -
              """
              -    Min-K%++ Attack: https://github.com/zjysteven/mink-plus-plus
              -"""
              -import torch as ch
              -import numpy as np
              -from mimir.attacks.all_attacks import Attack
              -from mimir.models import Model
              -from mimir.config import ExperimentConfig
              -
              -
              -class MinKPlusPlusAttack(Attack):
              -
              -    def __init__(self, config: ExperimentConfig, model: Model):
              -        super().__init__(config, model, ref_model=None)
              -
              -    @ch.no_grad()
              -    def _attack(self, document, probs, tokens=None, **kwargs):
              -        """
              -        Min-K%++ Attack. 
-        Gets token probabilities, normalizes them with the mean and std over the whole categorical distribution,
              -        and returns normalized likelihood when computed over top k% of ngrams.
              -        """
              -        # Hyper-params specific to min-k attack
              -        k: float = kwargs.get("k", 0.2)
              -        all_probs = kwargs.get("all_probs", None)
              -
-        # these are all log probabilities
              -        target_prob, all_probs = (
              -            (probs, all_probs)
              -            if (probs is not None and all_probs is not None)
              -            else self.model.get_probabilities(document, tokens=tokens, return_all_probs=True)
              -        )
              -        
              -        mu = (ch.exp(all_probs) * all_probs).sum(-1)
              -        sigma = (ch.exp(all_probs) * ch.square(all_probs)).sum(-1) - ch.square(mu)
              -        scores = (np.array(target_prob) - mu.numpy()) / sigma.sqrt().numpy()
              -        
              -        return -np.mean(sorted(scores)[:int(len(scores) * k)])
              -
              @@ -138,7 +98,6 @@

              Inherited members

              MIMIR -

              Index

                @@ -159,7 +118,7 @@

                -

                Generated by pdoc 0.10.0.

                +

                Generated by pdoc 0.11.1.

\ No newline at end of file

diff --git a/docs/attacks/neighborhood.html b/docs/attacks/neighborhood.html
index 211105d..031a422 100644
--- a/docs/attacks/neighborhood.html
+++ b/docs/attacks/neighborhood.html
@@ -2,18 +2,21 @@
 mimir.attacks.neighborhood API documentation
                @@ -23,586 +26,6 @@

                Module mimir.attacks.neighborhood

                Neighborhood-MIA attack https://arxiv.org/pdf/2305.18462.pdf
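In short, a document is scored by how its likelihood under the target model compares with the average likelihood of slightly perturbed "neighbor" versions of it, which the implementation below generates with a mask-filling model (T5 or a BERT-family model). A sketch of just the final score computation, with the per-text scores assumed to be already available:

import numpy as np


def neighborhood_score(target_score: float, neighbor_scores) -> float:
    # Score for the original document minus the mean score of its neighbors.
    return target_score - float(np.mean(neighbor_scores))


# Fabricated scores; within mimir these would come from target_model.get_ll and
# target_model.get_lls on the neighbors produced by NeighborhoodAttack.get_neighbors.
print(neighborhood_score(2.1, [2.9, 3.0, 2.8, 3.1]))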

                -
                - -Expand source code - -
                """
                -    Neighborhood-MIA attack https://arxiv.org/pdf/2305.18462.pdf
                -"""
                -
                -from heapq import nlargest
                -import torch
                -import re
                -import numpy as np
                -from tqdm import tqdm
                -import random
                -import transformers
                -from typing import List
                -
                -from mimir.config import ExperimentConfig
                -from mimir.attacks.attack_utils import count_masks, apply_extracted_fills
                -from mimir.models import Model, ReferenceModel
                -from mimir.attacks.all_attacks import Attack
                -
                -
                -class NeighborhoodAttack(Attack):
                -
                -    def __init__(
                -        self,
                -        config: ExperimentConfig,
                -        target_model: Model,
                -        ref_model: ReferenceModel = None,
                -        **kwargs,
                -    ):
                -        super().__init__(config, target_model, ref_model=None)
                -        self.ref_model = self._pick_neighbor_model()
                -        assert issubclass(type(self.ref_model), MaskFillingModel), "ref_model must be MaskFillingModel for neighborhood attack"
                -
                -    def get_mask_model(self):
                -        """
                -            Return the mask filling model.
                -        """
                -        return self.ref_model
                -
                -    def create_fill_dictionary(self, data):
                -        """
-            (Only valid for T5 model) Create fill-dictionary used for random_fills
                -        """
                -        neigh_config = self.config.neighborhood_config
                -        if "t5" in neigh_config.model and neigh_config.random_fills:
                -            if not self.config.pretokenized:
                -                # TODO: maybe can be done if detokenized, but currently not supported
                -                self.ref_model.create_fill_dictionary(data)
                -
                -    def _pick_neighbor_model(self):
                -        """
                -            Select and load the mask filling model requested in the config.
                -        """
                -        # mask filling t5 model
                -        mask_model = None
                -        neigh_config = self.config.neighborhood_config
                -        env_config = self.config.env_config
                -
                -        model_kwargs = dict()
                -        if not neigh_config.random_fills:
                -            if env_config.int8:
                -                model_kwargs = dict(
                -                    load_in_8bit=True, device_map="auto", torch_dtype=torch.bfloat16
                -                )
                -            elif env_config.half:
                -                model_kwargs = dict(torch_dtype=torch.bfloat16)
                -            try:
                -                n_positions = (
                -                    512  # Should fix later, but for T-5 this is 512 indeed
                -                )
                -                # mask_model.config.n_positions
                -            except AttributeError:
                -                n_positions = self.config.max_tokens
                -        else:
                -            n_positions = self.config.max_tokens
                -        tokenizer_kwargs = {
                -            "model_max_length": n_positions,
                -        }
                -
                -        print(f"Loading mask filling model {neigh_config.model}...")
                -        if "t5" in neigh_config.model:
                -            mask_model = T5Model(
                -                self.config,
                -                model_kwargs=model_kwargs,
                -                tokenizer_kwargs=tokenizer_kwargs,
                -            )
                -        elif "bert" in neigh_config.model:
                -            mask_model = BertModel(self.config)
                -        else:
                -            raise ValueError(f"Unknown model {neigh_config.model}")
                -        # if config.dataset_member in ['english', 'german']:
                -        #     preproc_tokenizer = mask_tokenizer
                -        return mask_model
                -
                -    def load(self):
                -        """
                -        Any attack-specific steps (one-time) preparation
                -        """
                -        print("MOVING MASK MODEL TO GPU...", end="", flush=True)
                -        self.ref_model.load()
                -
                -    def get_neighbors(self, documents, **kwargs):
                -        """
                -            Generate neighbors for given documents.
                -        """
                -        n_perturbations = kwargs.get("n_perturbations", 1)
                -        span_length = kwargs.get("span_length", 10)
                -        neigh_config = self.config.neighborhood_config
                -        ceil_pct = neigh_config.ceil_pct
                -        kwargs = {}
                -        if type(self.ref_model) == T5Model:
                -            kwargs = {
                -                "span_length": span_length,
                -                "pct": neigh_config.pct_words_masked,
                -                "chunk_size": self.config.chunk_size,
                -                "ceil_pct": ceil_pct,
                -            }
                -        kwargs["n_perturbations"] = n_perturbations
                -
                -        # Generate neighbors
                -        neighbors = self.ref_model.generate_neighbors(documents, **kwargs)
                -        return neighbors
                -
                -    def _attack(self, document, probs, tokens=None, **kwargs):
                -        """
                -        Neighborhood attack score. Looks at difference in likelihood for given document and average likelihood of its neighbors
                -        """
                -        # documents here are actually neighbors
                -        batch_size = kwargs.get("batch_size", 4)
                -        substr_neighbors = kwargs.get("substr_neighbors", None)
                -        loss = kwargs.get("loss", None)
                -        if loss is None:
                -            loss = self.target_model.get_ll(document, probs=probs, tokens=tokens)
                -
                -        # Only evaluate neighborhood attack when not caching neighbors
                -        mean_substr_score = self.target_model.get_lls(
                -            substr_neighbors, batch_size=batch_size
                -        )
                -        d_based_score = loss - mean_substr_score
                -        return d_based_score
                -
                -
                -class MaskFillingModel(Model):
                -    def __init__(self, config: ExperimentConfig, **kwargs):
                -        super(MaskFillingModel, self).__init__(config, **kwargs)
                -        self.device = self.config.env_config.device_aux
                -        self.name = self.config.neighborhood_config.model
                -
                -    def generate_neighbors(self, texts, **kwargs) -> List[str]:
                -        raise NotImplementedError("generate_neighbors not implemented")
                -
                -
                -class T5Model(MaskFillingModel):
                -    def __init__(self, config: ExperimentConfig, **kwargs):
                -        super().__init__(config, **kwargs)
                -        model_kwargs = self.kwargs.get("model_kwargs", {})
                -        tokenizer_kwargs = self.kwargs.get("tokenizer_kwargs", {})
                -
                -        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
                -            self.name, **model_kwargs, cache_dir=self.cache_dir
                -        )
                -        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                -            self.name, **tokenizer_kwargs, cache_dir=self.cache_dir
                -        )
                -
                -        # define regex to match all <extra_id_*> tokens, where * is an integer
                -        self.pattern = re.compile(r"<extra_id_\d+>")
                -
                -    def create_fill_dictionary(self, data):
                -        self.FILL_DICTIONARY = set()
                -        for texts in data.values():
                -            for text in texts:
                -                self.FILL_DICTIONARY.update(text.split())
                -        self.FILL_DICTIONARY = sorted(list(self.FILL_DICTIONARY))
                -
                -    def tokenize_and_mask(
                -        self, text: str, span_length: int, pct: float, ceil_pct: bool = False
                -    ):
                -        buffer_size = self.config.neighborhood_config.buffer_size
                -
                -        tokens = text.split(" ")
                -        mask_string = "<<<mask>>>"
                -
                -        span_length = min(int(pct * len(tokens)), span_length)
                -        # avoid div zero:
                -
                -        span_length = max(1, span_length)
                -
                -        n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
                -        if ceil_pct:
                -            n_spans = np.ceil(n_spans)
                -        n_spans = int(n_spans)
                -
                -        n_masks = 0
                -        while n_masks < n_spans:
                -            start = np.random.randint(0, max(1, len(tokens) - span_length))
                -            end = start + span_length
                -            search_start = max(0, start - buffer_size)
                -            search_end = min(len(tokens), end + buffer_size)
                -            if mask_string not in tokens[search_start:search_end]:
                -                tokens[start:end] = [mask_string]
                -                n_masks += 1
                -
                -        # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
                -        num_filled = 0
                -        for idx, token in enumerate(tokens):
                -            if token == mask_string:
                -                tokens[idx] = f"<extra_id_{num_filled}>"
                -                num_filled += 1
                -        assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
                -        text = " ".join(tokens)
                -        return text
                -
                -    def extract_fills(self, texts):
                -        # remove <pad> from beginning of each text
                -        texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
                -
                -        # return the text in between each matched mask token
                -        extracted_fills = [self.pattern.split(x)[1:-1] for x in texts]
                -
                -        # remove whitespace around each fill
                -        extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
                -
                -        return extracted_fills
                -
                -    def replace_masks(self, texts: List[str]):
                -        """
                -        Replace each masked span with a sample from T5 mask_model
                -        """
                -        mask_top_p = self.config.neighborhood_config.top_p
                -        n_expected = count_masks(texts)
                -        stop_id = self.tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
                -        tokens = self.tokenizer(texts, return_tensors="pt", padding=True).to(
                -            self.device
                -        )
                -        outputs = self.model.generate(
                -            **tokens,
                -            max_length=150,
                -            do_sample=True,
                -            top_p=mask_top_p,
                -            num_return_sequences=1,
                -            eos_token_id=stop_id,
                -        )
                -        return self.tokenizer.batch_decode(outputs, skip_special_tokens=False)
                -
                -    def generate_neighbors_(self, texts: List[str], **kwargs):
                -        span_length: int = kwargs.get("span_length")
                -        pct: float = kwargs.get("pct")
                -        ceil_pct: bool = kwargs.get("ceil_pct", False)
                -        base_tokenizer = kwargs.get("base_tokenizer", None)
                -        neigh_config = self.config.neighborhood_config
                -
                -        if not neigh_config.random_fills:
                -            masked_texts = [
                -                self.tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts
                -            ]
                -            raw_fills = self.replace_masks(masked_texts)
                -            extracted_fills = self.extract_fills(raw_fills)
                -            perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
                -            idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
                -
                -            # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
                -            attempts = 1
                -            break_out_of_loop: bool = False
                -            while "" in perturbed_texts:
                -                if attempts > neigh_config.max_tries:
                -                    for idx in idxs:
                -                        perturbed_texts[idx] = texts[idx]
                -                    break_out_of_loop = True
                -                    break
                -                if break_out_of_loop:
                -                    break
                -                idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
                -                print(
                -                    f"WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}]."
                -                )
                -                masked_texts = [
                -                    self.tokenize_and_mask(x, span_length, pct, ceil_pct)
                -                    for idx, x in enumerate(texts)
                -                    if idx in idxs
                -                ]
                -                raw_fills = self.replace_masks(masked_texts)
                -                extracted_fills = self.extract_fills(raw_fills)
                -                new_perturbed_texts = apply_extracted_fills(
                -                    masked_texts, extracted_fills
                -                )
                -                for idx, x in zip(idxs, new_perturbed_texts):
                -                    perturbed_texts[idx] = x
                -                attempts += 1
                -        else:
                -            if neigh_config.random_fills_tokens:
                -                if base_tokenizer is None:
                -                    raise ValueError(
                -                        "base_tokenizer must be provided if random_fills and random_fills_tokens are True"
                -                    )
                -
                -                # tokenize base_tokenizer
                -                tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(
                -                    self.device
                -                )
                -                valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id
                -                replace_pct = neigh_config.pct_words_masked * (
                -                    neigh_config.span_length
                -                    / (neigh_config.span_length + 2 * neigh_config.buffer_size)
                -                )
                -
                -                # replace replace_pct of input_ids with random tokens
                -                random_mask = (
                -                    torch.rand(tokens.input_ids.shape, device=self.device) < replace_pct
                -                )
                -                random_mask &= valid_tokens
                -                random_tokens = torch.randint(
                -                    0,
                -                    base_tokenizer.vocab_size,
                -                    (random_mask.sum(),),
                -                    device=self.device,
                -                )
                -                # while any of the random tokens are special tokens, replace them with random non-special tokens
                -                while any(
                -                    base_tokenizer.decode(x) in base_tokenizer.all_special_tokens
                -                    for x in random_tokens
                -                ):
                -                    random_tokens = torch.randint(
                -                        0,
                -                        base_tokenizer.vocab_size,
                -                        (random_mask.sum(),),
                -                        device=self.device,
                -                    )
                -                tokens.input_ids[random_mask] = random_tokens
                -                perturbed_texts = base_tokenizer.batch_decode(
                -                    tokens.input_ids, skip_special_tokens=True
                -                )
                -            else:
                -                masked_texts = [
                -                    self.tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts
                -                ]
                -                perturbed_texts = masked_texts
                -                # replace each <extra_id_*> with neigh_config.span_length random words from FILL_DICTIONARY
                -                for idx, text in enumerate(perturbed_texts):
                -                    filled_text = text
                -                    for fill_idx in range(count_masks([text])[0]):
                -                        fill = random.sample(self.FILL_DICTIONARY, span_length)
                -                        filled_text = filled_text.replace(
                -                            f"<extra_id_{fill_idx}>", " ".join(fill)
                -                        )
                -                    assert (
                -                        count_masks([filled_text])[0] == 0
                -                    ), "Failed to replace all masks"
                -                    perturbed_texts[idx] = filled_text
                -
                -        return perturbed_texts
                -
                -    def generate_neighbors(self, texts, **kwargs) -> List[str]:
                -        n_neighbors = kwargs.get("n_perturbations", 25)
                -        # Repeat text if T-5 model
                -        texts_use = [x for x in texts for _ in range(n_neighbors)]
                -
                -        chunk_size = self.config.chunk_size
                -        if "11b" in self.config.neighborhood_config.model:
                -            chunk_size //= 2
                -
                -        outputs = []
                -        for i in tqdm(
                -            range(0, len(texts_use), chunk_size), desc="Applying perturbations"
                -        ):
                -            outputs.extend(
                -                self.generate_neighbors_(texts_use[i : i + chunk_size], **kwargs)
                -            )
                -        return outputs
                -
                -
                -class BertModel(MaskFillingModel):
                -    def __init__(self, config: ExperimentConfig, **kwargs):
                -        super().__init__(config, **kwargs)
                -        self.token_dropout = torch.nn.Dropout(p=0.7)
                -        if self.name == "bert":
                -            self.tokenizer = transformers.BertTokenizerFast.from_pretrained(
                -                "bert-base-uncased", cache_dir=self.cache_dir
                -            )
                -            self.model = transformers.BertForMaskedLM.from_pretrained(
                -                "bert-base-uncased", cache_dir=self.cache_dir
                -            )
                -        elif self.name == "distilbert":
                -            self.tokenizer = transformers.DistilBertTokenizer.from_pretrained(
                -                "distilbert-base-uncased", cache_dir=self.cache_dir
                -            )
                -            self.model = transformers.DistilBertForMaskedLM.from_pretrained(
                -                "distilbert-base-uncased", cache_dir=self.cache_dir
                -            )
                -        elif self.name == "roberta":
                -            self.tokenizer = transformers.RobertaTokenizer.from_pretrained(
                -                "roberta-base", cache_dir=self.cache_dir
                -            )
                -            self.model = transformers.RobertaForMaskedLM.from_pretrained(
                -                "roberta-base", cache_dir=self.cache_dir
                -            )
                -        else:
                -            raise ValueError(f"Invalid model name {self.name}")
                -
                -    def generate_neighbors(self, texts, **kwargs) -> List[str]:
                -        neighbors = []
                -        for text in tqdm(texts, desc="Generating neighbors"):
                -            neighbors.extend(self.generate_neighbors_(text, **kwargs))
                -        return neighbors
                -
                -    def generate_neighbors_(self, text: str, **kwargs):
                -        in_place_swap = self.config.neighborhood_config.original_tokenization_swap
                -
                -        tokenizer_output = self.tokenizer(
                -            text,
                -            padding=True,
                -            truncation=True,
                -            return_offsets_mapping=in_place_swap,
                -            max_length=self.config.max_tokens,
                -            return_tensors="pt",
                -        )
                -        text_tokenized = tokenizer_output.input_ids.to(self.device)
                -        n_neighbors = kwargs.get("n_perturbations", 25)
                -        num_tokens = len(text_tokenized[0, :])
                -        n_swap = int(num_tokens * self.config.neighborhood_config.pct_swap_bert)
                -
                -        if in_place_swap:
                -            token_positions = tokenizer_output.offset_mapping[0]
                -
                -        replacements = dict()
                -
                -        target_token_indices = range(1, num_tokens)
                -        for target_token_index in target_token_indices:
                -            target_token = text_tokenized[0, target_token_index]
                -            if self.name == "bert":
                -                embeds = self.model.bert.embeddings(text_tokenized)
                -            elif self.name == "distilbert":
                -                embeds = self.model.distilbert.embeddings(text_tokenized)
                -            elif self.name == "roberta":
                -                embeds = self.model.roberta.embeddings(text_tokenized)
                -
                -            embeds = torch.cat(
                -                (
                -                    embeds[:, :target_token_index, :],
                -                    self.token_dropout(embeds[:, target_token_index, :]).unsqueeze(
                -                        dim=0
                -                    ),
                -                    embeds[:, target_token_index + 1 :, :],
                -                ),
                -                dim=1,
                -            )
                -
                -            token_probs = torch.softmax(self.model(inputs_embeds=embeds).logits, dim=2)
                -
                -            original_prob = token_probs[0, target_token_index, target_token]
                -
                -            top_probabilities, top_candidates = torch.topk(
                -                token_probs[:, target_token_index, :], 6, dim=1
                -            )
                -
                -            for cand, prob in zip(top_candidates[0], top_probabilities[0]):
                -                if not cand == target_token:
                -                    # alt = torch.cat((text_tokenized[:,:target_token_index], torch.LongTensor([cand]).unsqueeze(0).to(device), text_tokenized[:,target_token_index+1:]), dim=1)
                -                    # alt_text = search_tokenizer.batch_decode(alt)[0]
                -                    if original_prob.item() == 1:
                -                        replacements[(target_token_index, cand)] = prob.item() / (
                -                            1 - 0.9
                -                        )
                -                    else:
                -                        replacements[(target_token_index, cand)] = prob.item() / (
                -                            1 - original_prob.item()
                -                        )
                -
                -        if self.config.neighborhood_config.neighbor_strategy == "deterministic":
                -            replacement_keys = nlargest(n_neighbors, replacements, key=replacements.get)
                -            replacements_new = dict()
                -            for rk in replacement_keys:
                -                replacements_new[rk] = replacements[rk]
                -
                -            replacements = replacements_new
                -
                -            # TODO: Not sure if this is needed (perhaps making sure we never take >= 100)? Consider removing later
                -            highest_scored = nlargest(100, replacements, key=replacements.get)
                -
                -            neighbors = []
                -            for single in highest_scored:
                -                target_token_index, cand = single
                -
                -                if in_place_swap:
                -                    # Get indices of original text that we want to swap out
                -                    start, end = token_positions[target_token_index]
                -                    # Get text corresponding to cand token
                -                    fill_in_text = self.tokenizer.decode(cand)
                -                    # Remove any '##' from prefix (since we're doing a plug back into text)
                -                    fill_in_text = fill_in_text.replace("##", "")
                -                    alt_text = text[:start] + fill_in_text + text[end:]
                -                else:
                -                    alt = text_tokenized
                -                    alt = torch.cat(
                -                        (
                -                            alt[:, :target_token_index],
                -                            torch.LongTensor([cand]).unsqueeze(0).to(self.device),
                -                            alt[:, target_token_index + 1 :],
                -                        ),
                -                        dim=1,
                -                    )
                -                    alt_text = self.tokenizer.batch_decode(alt)[0]
                -                    # Remove [CLS] and [SEP] tokens
                -                    alt_text = alt_text.replace("[CLS]", "").replace("[SEP]", "")
                -                    # texts.append((alt_text, replacements[single]))
                -                neighbors.append(alt_text)
                -
                -        elif self.config.neighborhood_config.neighbor_strategy == "random":
                -            if not in_place_swap:
                -                raise ValueError(
                -                    "Random neighbor strategy only works with in_place_swap=True right now"
                -                )
                -
                -            # Make new dict replacements_new with structure [key[0]]: (key[1], value) for all keys in replacements
                -            replacements_new = dict()
                -            for k, v in replacements.items():
                -                if k[0] not in replacements_new:
                -                    replacements_new[k[0]] = []
                -                replacements_new[k[0]].append((k[1].item(), v))
                -            # Sort each entry by score
                -            for k, v in replacements_new.items():
                -                replacements_new[k] = sorted(v, key=lambda x: x[1], reverse=True)
                -
                -            num_trials = int(1e3)
                -            replacements, scores = [], []
                -            for _ in range(num_trials):
                -                # Pick n_swap random positions
                -                swap_positions = np.random.choice(
                -                    list(replacements_new.keys()), n_swap, replace=False
                -                )
                -                # Out of all replacements, pick keys where target_token_index is in swap_positions
                -                picked = [replacements_new[x][0] for x in swap_positions]
                -                # Compute score (sum)
                -                score = sum([x[1] for x in picked])
                -                scores.append(score)
                -                # Also keep track of replacements (position, candidate)
                -                replacements.append(
                -                    [(i, replacements_new[i][0][0]) for i in swap_positions]
                -                )
                -
                -            # Out of all trials, pick n_neighbors combinations (highest scores)
                -            highest_scored = nlargest(
                -                n_neighbors, zip(scores, replacements), key=lambda x: x[0]
                -            )
                -
                -            neighbors = []
                -            for _, single in highest_scored:
                -                # Sort according to target_token_index
                -                single = sorted(single, key=lambda x: x[0])
                -                # Get corresponding positions in text
                -                single = [
                -                    (token_positions[target_token_index], cand)
                -                    for target_token_index, cand in single
                -                ]
                -                # Add start of text (before first swap)
                -                end_prev = 0
                -                alt_text = ""
                -                for (start, end), cand in single:
                -                    # Get text corresponding to cand token
                -                    fill_in_text = self.tokenizer.decode(cand)
                -                    # Remove any '##' from prefix (since we're doing a plug back into text)
                -                    fill_in_text = fill_in_text.replace("##", "")
                -                    alt_text += text[end_prev:start] + fill_in_text
                -                    end_prev = end
                -                # Add remainder text (after last swap)
                -                start, end = single[-1][0]
                -                alt_text += text[end:]
                -                neighbors.append(alt_text)
                -
                -        else:
                -            raise NotImplementedError(
                -                f"Invalid neighbor strategy {self.config.neighborhood_config.neighbor_strategy}"
                -            )
                -
                -        # return texts
                -        return neighbors
                -
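A minimal sketch of the candidate-scoring heuristic used in the listing above; the function name and variables are illustrative, not from the source:

    def score_replacement(candidate_prob: float, original_prob: float) -> float:
        # Candidates are ranked by p(candidate) / (1 - p(original)); the 1 - 0.9
        # fallback mirrors the guard above against dividing by zero when the
        # original token has probability 1 under the masked LM.
        denom = (1.0 - original_prob) if original_prob < 1.0 else (1.0 - 0.9)
        return candidate_prob / denom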
                @@ -836,21 +259,6 @@

                Ancestors

              • Model
              • torch.nn.modules.module.Module
              • -

                Class variables

                -
                -
                var call_super_init : bool
                -
                -
                -
                -
                var dump_patches : bool
                -
                -
                -
                -
                var training : bool
                -
                -
                -
                -

                Methods

                @@ -858,205 +266,18 @@

                Methods

                -
                - -Expand source code - -
                def generate_neighbors(self, texts, **kwargs) -> List[str]:
                -    neighbors = []
                -    for text in tqdm(texts, desc="Generating neighbors"):
                -        neighbors.extend(self.generate_neighbors_(text, **kwargs))
                -    return neighbors
                -
                def generate_neighbors_(self, text: str, **kwargs)
                -
                - -Expand source code - -
                def generate_neighbors_(self, text: str, **kwargs):
                -    in_place_swap = self.config.neighborhood_config.original_tokenization_swap
                -
                -    tokenizer_output = self.tokenizer(
                -        text,
                -        padding=True,
                -        truncation=True,
                -        return_offsets_mapping=in_place_swap,
                -        max_length=self.config.max_tokens,
                -        return_tensors="pt",
                -    )
                -    text_tokenized = tokenizer_output.input_ids.to(self.device)
                -    n_neighbors = kwargs.get("n_perturbations", 25)
                -    num_tokens = len(text_tokenized[0, :])
                -    n_swap = int(num_tokens * self.config.neighborhood_config.pct_swap_bert)
                -
                -    if in_place_swap:
                -        token_positions = tokenizer_output.offset_mapping[0]
                -
                -    replacements = dict()
                -
                -    target_token_indices = range(1, num_tokens)
                -    for target_token_index in target_token_indices:
                -        target_token = text_tokenized[0, target_token_index]
                -        if self.name == "bert":
                -            embeds = self.model.bert.embeddings(text_tokenized)
                -        elif self.name == "distilbert":
                -            embeds = self.model.distilbert.embeddings(text_tokenized)
                -        elif self.name == "roberta":
                -            embeds = self.model.roberta.embeddings(text_tokenized)
                -
                -        embeds = torch.cat(
                -            (
                -                embeds[:, :target_token_index, :],
                -                self.token_dropout(embeds[:, target_token_index, :]).unsqueeze(
                -                    dim=0
                -                ),
                -                embeds[:, target_token_index + 1 :, :],
                -            ),
                -            dim=1,
                -        )
                -
                -        token_probs = torch.softmax(self.model(inputs_embeds=embeds).logits, dim=2)
                -
                -        original_prob = token_probs[0, target_token_index, target_token]
                -
                -        top_probabilities, top_candidates = torch.topk(
                -            token_probs[:, target_token_index, :], 6, dim=1
                -        )
                -
                -        for cand, prob in zip(top_candidates[0], top_probabilities[0]):
                -            if not cand == target_token:
                -                # alt = torch.cat((text_tokenized[:,:target_token_index], torch.LongTensor([cand]).unsqueeze(0).to(device), text_tokenized[:,target_token_index+1:]), dim=1)
                -                # alt_text = search_tokenizer.batch_decode(alt)[0]
                -                if original_prob.item() == 1:
                -                    replacements[(target_token_index, cand)] = prob.item() / (
                -                        1 - 0.9
                -                    )
                -                else:
                -                    replacements[(target_token_index, cand)] = prob.item() / (
                -                        1 - original_prob.item()
                -                    )
                -
                -    if self.config.neighborhood_config.neighbor_strategy == "deterministic":
                -        replacement_keys = nlargest(n_neighbors, replacements, key=replacements.get)
                -        replacements_new = dict()
                -        for rk in replacement_keys:
                -            replacements_new[rk] = replacements[rk]
                -
                -        replacements = replacements_new
                -
                -        # TODO: Not sure if this is needed (perhaps making sure we never take >= 100)? Consider removing later
                -        highest_scored = nlargest(100, replacements, key=replacements.get)
                -
                -        neighbors = []
                -        for single in highest_scored:
                -            target_token_index, cand = single
                -
                -            if in_place_swap:
                -                # Get indices of original text that we want to swap out
                -                start, end = token_positions[target_token_index]
                -                # Get text corresponding to cand token
                -                fill_in_text = self.tokenizer.decode(cand)
                -                # Remove any '##' from prefix (since we're doing a plug back into text)
                -                fill_in_text = fill_in_text.replace("##", "")
                -                alt_text = text[:start] + fill_in_text + text[end:]
                -            else:
                -                alt = text_tokenized
                -                alt = torch.cat(
                -                    (
                -                        alt[:, :target_token_index],
                -                        torch.LongTensor([cand]).unsqueeze(0).to(self.device),
                -                        alt[:, target_token_index + 1 :],
                -                    ),
                -                    dim=1,
                -                )
                -                alt_text = self.tokenizer.batch_decode(alt)[0]
                -                # Remove [CLS] and [SEP] tokens
                -                alt_text = alt_text.replace("[CLS]", "").replace("[SEP]", "")
                -                # texts.append((alt_text, replacements[single]))
                -            neighbors.append(alt_text)
                -
                -    elif self.config.neighborhood_config.neighbor_strategy == "random":
                -        if not in_place_swap:
                -            raise ValueError(
                -                "Random neighbor strategy only works with in_place_swap=True right now"
                -            )
                -
                -        # Make new dict replacements_new with structure [key[0]]: (key[1], value) for all keys in replacements
                -        replacements_new = dict()
                -        for k, v in replacements.items():
                -            if k[0] not in replacements_new:
                -                replacements_new[k[0]] = []
                -            replacements_new[k[0]].append((k[1].item(), v))
                -        # Sort each entry by score
                -        for k, v in replacements_new.items():
                -            replacements_new[k] = sorted(v, key=lambda x: x[1], reverse=True)
                -
                -        num_trials = int(1e3)
                -        replacements, scores = [], []
                -        for _ in range(num_trials):
                -            # Pick n_swap random positions
                -            swap_positions = np.random.choice(
                -                list(replacements_new.keys()), n_swap, replace=False
                -            )
                -            # Out of all replacements, pick keys where target_token_index is in swap_positions
                -            picked = [replacements_new[x][0] for x in swap_positions]
                -            # Compute score (sum)
                -            score = sum([x[1] for x in picked])
                -            scores.append(score)
                -            # Also keep track of replacements (position, candidate)
                -            replacements.append(
                -                [(i, replacements_new[i][0][0]) for i in swap_positions]
                -            )
                -
                -        # Out of all trials, pick n_neighbors combinations (highest scores)
                -        highest_scored = nlargest(
                -            n_neighbors, zip(scores, replacements), key=lambda x: x[0]
                -        )
                -
                -        neighbors = []
                -        for _, single in highest_scored:
                -            # Sort according to target_token_index
                -            single = sorted(single, key=lambda x: x[0])
                -            # Get corresponding positions in text
                -            single = [
                -                (token_positions[target_token_index], cand)
                -                for target_token_index, cand in single
                -            ]
                -            # Add start of text (before first swap)
                -            end_prev = 0
                -            alt_text = ""
                -            for (start, end), cand in single:
                -                # Get text corresponding to cand token
                -                fill_in_text = self.tokenizer.decode(cand)
                -                # Remove any '##' from prefix (since we're doing a plug back into text)
                -                fill_in_text = fill_in_text.replace("##", "")
                -                alt_text += text[end_prev:start] + fill_in_text
                -                end_prev = end
                -            # Add remainder text (after last swap)
                -            start, end = single[-1][0]
                -            alt_text += text[end:]
                -            neighbors.append(alt_text)
                -
                -    else:
                -        raise NotImplementedError(
                -            f"Invalid neighbor strategy {self.config.neighborhood_config.neighbor_strategy}"
                -        )
                -
                -    # return texts
                -    return neighbors
                -

                Inherited members

                • MaskFillingModel: -

                  Class variables

                  -
                  -
                  var call_super_init : bool
                  -
                  -
                  -
                  -
                  var dump_patches : bool
                  -
                  -
                  -
                  -
                  var training : bool
                  -
                  -
                  -
                  -

                  Methods

                  @@ -1120,20 +326,12 @@

                  Methods

                  -
                  - -Expand source code - -
                  def generate_neighbors(self, texts, **kwargs) -> List[str]:
                  -    raise NotImplementedError("generate_neighbors not implemented")
                  -

                  Inherited members

                  • Model:
                      -
                    • forward
                    • get_ll
                    • get_probabilities
                    • load
                    • @@ -1287,68 +485,18 @@

                      Methods

(Only valid for T5 model) Create fill-dictionary used for random_fills

                      -
                      - -Expand source code - -
                      def create_fill_dictionary(self, data):
                      -    """
-        (Only valid for T5 model) Create fill-dictionary used for random_fills
                      -    """
                      -    neigh_config = self.config.neighborhood_config
                      -    if "t5" in neigh_config.model and neigh_config.random_fills:
                      -        if not self.config.pretokenized:
                      -            # TODO: maybe can be done if detokenized, but currently not supported
                      -            self.ref_model.create_fill_dictionary(data)
                      -
                      def get_mask_model(self)

                      Return the mask filling model.

                      -
                      - -Expand source code - -
                      def get_mask_model(self):
                      -    """
                      -        Return the mask filling model.
                      -    """
                      -    return self.ref_model
                      -
                      def get_neighbors(self, documents, **kwargs)

                      Generate neighbors for given documents.

                      -
                      - -Expand source code - -
                      def get_neighbors(self, documents, **kwargs):
                      -    """
                      -        Generate neighbors for given documents.
                      -    """
                      -    n_perturbations = kwargs.get("n_perturbations", 1)
                      -    span_length = kwargs.get("span_length", 10)
                      -    neigh_config = self.config.neighborhood_config
                      -    ceil_pct = neigh_config.ceil_pct
                      -    kwargs = {}
                      -    if type(self.ref_model) == T5Model:
                      -        kwargs = {
                      -            "span_length": span_length,
                      -            "pct": neigh_config.pct_words_masked,
                      -            "chunk_size": self.config.chunk_size,
                      -            "ceil_pct": ceil_pct,
                      -        }
                      -    kwargs["n_perturbations"] = n_perturbations
                      -
                      -    # Generate neighbors
                      -    neighbors = self.ref_model.generate_neighbors(documents, **kwargs)
                      -    return neighbors
                      -
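For orientation, a minimal usage sketch of the method above, assuming an already-constructed NeighborhoodAttack instance named attack and a list of strings named documents (both names are illustrative):

    # Each document gets n_perturbations neighbors from the loaded
    # mask-filling reference model (BERT-style or T5, per the config).
    neighbors = attack.get_neighbors(documents, n_perturbations=25, span_length=10)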

                      Inherited members

                      @@ -1596,21 +744,6 @@

                      Ancestors

                    • Model
                    • torch.nn.modules.module.Module
                    -

                    Class variables

                    -
                    -
                    var call_super_init : bool
                    -
                    -
                    -
                    -
                    var dump_patches : bool
                    -
                    -
                    -
                    -
                    var training : bool
                    -
                    -
                    -
                    -

                    Methods

                    @@ -1618,269 +751,42 @@

                    Methods

                    -
                    - -Expand source code - -
                    def create_fill_dictionary(self, data):
                    -    self.FILL_DICTIONARY = set()
                    -    for texts in data.values():
                    -        for text in texts:
                    -            self.FILL_DICTIONARY.update(text.split())
                    -    self.FILL_DICTIONARY = sorted(list(self.FILL_DICTIONARY))
                    -
                    def extract_fills(self, texts)
                    -
                    - -Expand source code - -
                    def extract_fills(self, texts):
                    -    # remove <pad> from beginning of each text
                    -    texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
                    -
                    -    # return the text in between each matched mask token
                    -    extracted_fills = [self.pattern.split(x)[1:-1] for x in texts]
                    -
                    -    # remove whitespace around each fill
                    -    extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
                    -
                    -    return extracted_fills
                    -
                    def generate_neighbors(self, texts, **kwargs) ‑> List[str]
                    -
                    - -Expand source code - -
                    def generate_neighbors(self, texts, **kwargs) -> List[str]:
                    -    n_neighbors = kwargs.get("n_perturbations", 25)
-    # Repeat each text n_neighbors times (one T5 perturbation is produced per copy)
                    -    texts_use = [x for x in texts for _ in range(n_neighbors)]
                    -
                    -    chunk_size = self.config.chunk_size
                    -    if "11b" in self.config.neighborhood_config.model:
                    -        chunk_size //= 2
                    -
                    -    outputs = []
                    -    for i in tqdm(
                    -        range(0, len(texts_use), chunk_size), desc="Applying perturbations"
                    -    ):
                    -        outputs.extend(
                    -            self.generate_neighbors_(texts_use[i : i + chunk_size], **kwargs)
                    -        )
                    -    return outputs
                    -
                    def generate_neighbors_(self, texts: List[str], **kwargs)
                    -
                    - -Expand source code - -
                    def generate_neighbors_(self, texts: List[str], **kwargs):
                    -    span_length: int = kwargs.get("span_length")
                    -    pct: float = kwargs.get("pct")
                    -    ceil_pct: bool = kwargs.get("ceil_pct", False)
                    -    base_tokenizer = kwargs.get("base_tokenizer", None)
                    -    neigh_config = self.config.neighborhood_config
                    -
                    -    if not neigh_config.random_fills:
                    -        masked_texts = [
                    -            self.tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts
                    -        ]
                    -        raw_fills = self.replace_masks(masked_texts)
                    -        extracted_fills = self.extract_fills(raw_fills)
                    -        perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
                    -        idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
                    -
                    -        # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
                    -        attempts = 1
                    -        break_out_of_loop: bool = False
                    -        while "" in perturbed_texts:
                    -            if attempts > neigh_config.max_tries:
                    -                for idx in idxs:
                    -                    perturbed_texts[idx] = texts[idx]
                    -                break_out_of_loop = True
                    -                break
                    -            if break_out_of_loop:
                    -                break
                    -            idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
                    -            print(
                    -                f"WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}]."
                    -            )
                    -            masked_texts = [
                    -                self.tokenize_and_mask(x, span_length, pct, ceil_pct)
                    -                for idx, x in enumerate(texts)
                    -                if idx in idxs
                    -            ]
                    -            raw_fills = self.replace_masks(masked_texts)
                    -            extracted_fills = self.extract_fills(raw_fills)
                    -            new_perturbed_texts = apply_extracted_fills(
                    -                masked_texts, extracted_fills
                    -            )
                    -            for idx, x in zip(idxs, new_perturbed_texts):
                    -                perturbed_texts[idx] = x
                    -            attempts += 1
                    -    else:
                    -        if neigh_config.random_fills_tokens:
                    -            if base_tokenizer is None:
                    -                raise ValueError(
                    -                    "base_tokenizer must be provided if random_fills and random_fills_tokens are True"
                    -                )
                    -
                    -            # tokenize base_tokenizer
                    -            tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(
                    -                self.device
                    -            )
                    -            valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id
                    -            replace_pct = neigh_config.pct_words_masked * (
                    -                neigh_config.span_length
                    -                / (neigh_config.span_length + 2 * neigh_config.buffer_size)
                    -            )
                    -
                    -            # replace replace_pct of input_ids with random tokens
                    -            random_mask = (
                    -                torch.rand(tokens.input_ids.shape, device=self.device) < replace_pct
                    -            )
                    -            random_mask &= valid_tokens
                    -            random_tokens = torch.randint(
                    -                0,
                    -                base_tokenizer.vocab_size,
                    -                (random_mask.sum(),),
                    -                device=self.device,
                    -            )
                    -            # while any of the random tokens are special tokens, replace them with random non-special tokens
                    -            while any(
                    -                base_tokenizer.decode(x) in base_tokenizer.all_special_tokens
                    -                for x in random_tokens
                    -            ):
                    -                random_tokens = torch.randint(
                    -                    0,
                    -                    base_tokenizer.vocab_size,
                    -                    (random_mask.sum(),),
                    -                    device=self.device,
                    -                )
                    -            tokens.input_ids[random_mask] = random_tokens
                    -            perturbed_texts = base_tokenizer.batch_decode(
                    -                tokens.input_ids, skip_special_tokens=True
                    -            )
                    -        else:
                    -            masked_texts = [
                    -                self.tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts
                    -            ]
                    -            perturbed_texts = masked_texts
                    -            # replace each <extra_id_*> with neigh_config.span_length random words from FILL_DICTIONARY
                    -            for idx, text in enumerate(perturbed_texts):
                    -                filled_text = text
                    -                for fill_idx in range(count_masks([text])[0]):
                    -                    fill = random.sample(self.FILL_DICTIONARY, span_length)
                    -                    filled_text = filled_text.replace(
                    -                        f"<extra_id_{fill_idx}>", " ".join(fill)
                    -                    )
                    -                assert (
                    -                    count_masks([filled_text])[0] == 0
                    -                ), "Failed to replace all masks"
                    -                perturbed_texts[idx] = filled_text
                    -
                    -    return perturbed_texts
                    -
                    def replace_masks(self, texts: List[str])

                    Replace each masked span with a sample from T5 mask_model

                    -
                    - -Expand source code - -
                    def replace_masks(self, texts: List[str]):
                    -    """
                    -    Replace each masked span with a sample from T5 mask_model
                    -    """
                    -    mask_top_p = self.config.neighborhood_config.top_p
                    -    n_expected = count_masks(texts)
                    -    stop_id = self.tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
                    -    tokens = self.tokenizer(texts, return_tensors="pt", padding=True).to(
                    -        self.device
                    -    )
                    -    outputs = self.model.generate(
                    -        **tokens,
                    -        max_length=150,
                    -        do_sample=True,
                    -        top_p=mask_top_p,
                    -        num_return_sequences=1,
                    -        eos_token_id=stop_id,
                    -    )
                    -    return self.tokenizer.batch_decode(outputs, skip_special_tokens=False)
                    -
                    def tokenize_and_mask(self, text: str, span_length: int, pct: float, ceil_pct: bool = False)
                    -
                    - -Expand source code - -
                    def tokenize_and_mask(
                    -    self, text: str, span_length: int, pct: float, ceil_pct: bool = False
                    -):
                    -    buffer_size = self.config.neighborhood_config.buffer_size
                    -
                    -    tokens = text.split(" ")
                    -    mask_string = "<<<mask>>>"
                    -
                    -    span_length = min(int(pct * len(tokens)), span_length)
                    -    # avoid div zero:
                    -
                    -    span_length = max(1, span_length)
                    -
                    -    n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
                    -    if ceil_pct:
                    -        n_spans = np.ceil(n_spans)
                    -    n_spans = int(n_spans)
                    -
                    -    n_masks = 0
                    -    while n_masks < n_spans:
                    -        start = np.random.randint(0, max(1, len(tokens) - span_length))
                    -        end = start + span_length
                    -        search_start = max(0, start - buffer_size)
                    -        search_end = min(len(tokens), end + buffer_size)
                    -        if mask_string not in tokens[search_start:search_end]:
                    -            tokens[start:end] = [mask_string]
                    -            n_masks += 1
                    -
                    -    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
                    -    num_filled = 0
                    -    for idx, token in enumerate(tokens):
                    -        if token == mask_string:
                    -            tokens[idx] = f"<extra_id_{num_filled}>"
                    -            num_filled += 1
                    -    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
                    -    text = " ".join(tokens)
                    -    return text
                    -
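Taken together, tokenize_and_mask, replace_masks, and extract_fills implement T5-style span corruption and refilling. A rough sketch of the flow, assuming a loaded T5Model instance named mask_model; the example sentence and the sampled fills are purely illustrative:

    masked = mask_model.tokenize_and_mask(
        "the quick brown fox jumps over the lazy dog", span_length=2, pct=0.3
    )
    # e.g. "the quick <extra_id_0> jumps over the <extra_id_1> dog"
    raw_fills = mask_model.replace_masks([masked])  # T5 samples text for each <extra_id_*> sentinel
    fills = mask_model.extract_fills(raw_fills)     # e.g. [["brown fox", "lazy"]]
    # The module-level helper apply_extracted_fills then splices the fills back into the masked texts.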

                    Inherited members

                    • MaskFillingModel: @@ -1960,7 +856,7 @@

                      -

                      Generated by pdoc 0.10.0.

                      +

                      Generated by pdoc 0.11.1.

-
\ No newline at end of file
+
diff --git a/docs/attacks/quantile.html b/docs/attacks/quantile.html
index c03cb29..1a9e564 100644
--- a/docs/attacks/quantile.html
+++ b/docs/attacks/quantile.html
@@ -2,19 +2,22 @@
mimir.attacks.quantile API documentation
                      @@ -25,127 +28,6 @@

                      Module mimir.attacks.quantile

                      Implementation of the attack proposed in 'Scalable Membership Inference Attacks via Quantile Regression' https://arxiv.org/pdf/2307.03694.pdf

                      -
                      - -Expand source code - -
                      """
                      -    Implementation of the attack proposed in 'Scalable Membership Inference Attacks via Quantile Regression'
                      -    https://arxiv.org/pdf/2307.03694.pdf
                      -"""
                      -import torch as ch
                      -from mimir.models import QuantileReferenceModel, Model
                      -from transformers import TrainingArguments
                      -from sklearn.metrics import mean_squared_error
                      -from transformers import TrainingArguments, Trainer
                      -from datasets import Dataset
                      -
                      -from mimir.attacks.all_attacks import Attack
                      -
                      -
                      -class CustomTrainer(Trainer):
                      -    def __init__(
                      -        self,
                      -        alpha_fpr,
                      -        **kwargs,
                      -    ):
                      -        super().__init__(**kwargs)
                      -        self.alpha_fpr = alpha_fpr
                      -
                      -    def compute_loss(self, model, inputs, return_outputs=False):
                      -        labels = inputs.pop("labels")
                      -        # forward pass
                      -        outputs = model(**inputs)
                      -        logits = outputs.get("logits")
                      -        loss = ch.mean(
                      -            ch.max(
                      -                self.alpha_fpr * (logits - labels),
                      -                (1 - self.alpha_fpr) * (labels - logits),
                      -            )
                      -        )
                      -        return (loss, outputs) if return_outputs else loss
                      -
                      -
                      -class QuantileAttack(Attack):
                      -    """
                      -    Implementation of the attack proposed in 'Scalable Membership Inference Attacks via Quantile Regression'
                      -    https://arxiv.org/pdf/2307.03694.pdf
                      -    """
                      -
                      -    def __init__(self, config, model: Model, alpha: float):
                      -        """
                      -        alpha (float): Desired FPR
                      -        """
                      -        ref_model = QuantileReferenceModel(
                      -            config, name="Sreevishnu/funnel-transformer-small-imdb"
                      -        )
-        super().__init__(config, model, ref_model)
                      -        self.alpha = alpha
                      -
                      -    def _train_quantile_model(self, dataset):
                      -        def tokenize_function(examples):
                      -            return self.ref_model.tokenizer(
                      -                examples["text"], padding="max_length", truncation=True
                      -            )
                      -
                      -        tokenized_dataset = dataset.map(tokenize_function, batched=True)
                      -        training_args = TrainingArguments(
                      -            output_dir="quantile_ref_model",
                      -            evaluation_strategy="epoch",
                      -            num_train_epochs=1,
                      -        )
                      -
                      -        def compute_metrics(eval_pred):
                      -            predictions, labels = eval_pred
                      -            rmse = mean_squared_error(labels, predictions, squared=False)
                      -            return {"rmse": rmse}
                      -
                      -        trainer = CustomTrainer(
                      -            alpha_fpr=self.alpha,
                      -            model=self.ref_model.model,
                      -            args=training_args,
                      -            train_dataset=tokenized_dataset,
                      -            eval_dataset=tokenized_dataset,
                      -            compute_metrics=compute_metrics,
                      -        )
                      -        # Train quantile model
                      -        trainer.train()
                      -
                      -    def prepare(self, known_non_members):
                      -        """
                      -        Step 1: Use non-member dataset, collect confidence scores for correct label.
                      -        Step 2: Train a quantile regression model that takes X as input and predicts quantile. Use pinball loss
                      -        Step 3: Test by checking if member: score is higher than output of quantile regression model.
                      -        """
                      -
                      -        # Step 1: Use non-member dataset, collect confidence scores for correct label.
                      -        # Get likelihood scores from target model for known_non_members
                      -        # Note that these non-members should be different from the ones in testing
                      -        scores = [self.target_model.get_ll(x) for x in known_non_members]
                      -        # Construct a dataset out of this to be used in Huggingface, with
                      -        # "text" containing the actual data, and "labels" containing the scores
                      -        dataset = Dataset.from_dict({"text": known_non_members, "labels": scores})
                      -
                      -        # Step 2: Train a quantile regression model that takes X as input and predicts quantile. Use pinball loss
                      -        self._train_quantile_model(dataset)
                      -
                      -    def attack(self, document, **kwargs):
                      -        # Step 3: Test by checking if member: score is higher than output of quantile regression model.
                      -
                      -        # Get likelihood score from target model for doc
                      -        ll = self.target_model.get_ll(document)
                      -
                      -        # Return ll - quantile_model(doc)
                      -        tokenized = self.ref_model.tokenizer(document, return_tensors="pt")
                      -        # Shift items in the dictionary to the correct device
                      -        tokenized = {k: v.to(self.ref_model.model.device, non_blocking=True) for k, v in tokenized.items()}
                      -        quantile_score = self.ref_model.model(**tokenized)
                      -        print(quantile_score)
                      -        quantile_score = quantile_score.logits.item()
                      -
                      -        # We want higher score to be non-member
                      -        return quantile_score - ll
                      -
                      @@ -206,7 +88,10 @@

                      Args

                      compute_metrics (Callable[[EvalPrediction], Dict], optional): The function that will be used to compute metrics at evaluation. Must take a [EvalPrediction] and return -a dictionary string to metric values. +a dictionary string to metric values. Note When passing TrainingArgs with batch_eval_metrics set to +True, your compute_metrics function must take a boolean compute_result argument. This will be triggered +after the last eval batch to signal that the function needs to calculate and return the global summary +statistics rather than accumulating the batch-level statistics. callbacks (List of [TrainerCallback], optional): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in here.

                      @@ -274,23 +159,6 @@

                      Methods

                      How the loss is computed by Trainer. By default, all models return the loss in the first element.

                      Subclass and override for custom behavior.

                      -
                      - -Expand source code - -
                      def compute_loss(self, model, inputs, return_outputs=False):
                      -    labels = inputs.pop("labels")
                      -    # forward pass
                      -    outputs = model(**inputs)
                      -    logits = outputs.get("logits")
                      -    loss = ch.mean(
                      -        ch.max(
                      -            self.alpha_fpr * (logits - labels),
                      -            (1 - self.alpha_fpr) * (labels - logits),
                      -        )
                      -    )
                      -    return (loss, outputs) if return_outputs else loss
                      -
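The expression above is the pinball (quantile) loss with prediction y_hat = logits and label y = labels:

    loss = mean( max( alpha * (y_hat - y), (1 - alpha) * (y - y_hat) ) )

Under-prediction is weighted by (1 - alpha) and over-prediction by alpha, so with alpha set to the desired false-positive rate the minimizer sits near the (1 - alpha) quantile of the non-member score distribution, which, on this reading, gives a per-example threshold calibrated to FPR alpha.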
                      @@ -399,28 +267,6 @@

                      Methods

Step 1: Use non-member dataset, collect confidence scores for correct label. Step 2: Train a quantile regression model that takes X as input and predicts quantile, using the pinball loss. Step 3: Test by checking if member: score is higher than output of quantile regression model.

                      -
                      - -Expand source code - -
                      def prepare(self, known_non_members):
                      -    """
                      -    Step 1: Use non-member dataset, collect confidence scores for correct label.
                      -    Step 2: Train a quantile regression model that takes X as input and predicts quantile. Use pinball loss
                      -    Step 3: Test by checking if member: score is higher than output of quantile regression model.
                      -    """
                      -
                      -    # Step 1: Use non-member dataset, collect confidence scores for correct label.
                      -    # Get likelihood scores from target model for known_non_members
                      -    # Note that these non-members should be different from the ones in testing
                      -    scores = [self.target_model.get_ll(x) for x in known_non_members]
                      -    # Construct a dataset out of this to be used in Huggingface, with
                      -    # "text" containing the actual data, and "labels" containing the scores
                      -    dataset = Dataset.from_dict({"text": known_non_members, "labels": scores})
                      -
                      -    # Step 2: Train a quantile regression model that takes X as input and predicts quantile. Use pinball loss
                      -    self._train_quantile_model(dataset)
                      -
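A rough end-to-end sketch of the calibrate-then-score flow described above; config, target_model, known_non_members, and candidate_text are assumed to exist and are not defined in these docs:

    # alpha is the desired false-positive rate for the quantile regressor.
    quantile_attack = QuantileAttack(config, target_model, alpha=0.05)
    quantile_attack.prepare(known_non_members)      # Steps 1-2: collect scores, train regressor
    score = quantile_attack.attack(candidate_text)  # Step 3: quantile prediction minus target loss
    # Higher scores suggest non-members; lower scores suggest members.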

                      Inherited members

                      @@ -442,7 +288,6 @@

                      Inherited members

                      MIMIR -

                      Index

                        @@ -472,7 +317,7 @@

                        -

                        Generated by pdoc 0.10.0.

                        +

                        Generated by pdoc 0.11.1.

-
\ No newline at end of file
+
diff --git a/docs/attacks/reference.html b/docs/attacks/reference.html
index 7d1dd6d..f8e4029 100644
--- a/docs/attacks/reference.html
+++ b/docs/attacks/reference.html
@@ -2,18 +2,21 @@
mimir.attacks.reference API documentation
                        @@ -23,37 +26,6 @@

                        Module mimir.attacks.reference

                        Reference-based attacks.

                        -
                        - -Expand source code - -
                        """
                        -    Reference-based attacks.
                        -"""
                        -from mimir.attacks.all_attacks import Attack
                        -from mimir.models import Model, ReferenceModel
                        -from mimir.config import ExperimentConfig
                        -
                        -
                        -class ReferenceAttack(Attack):
                        -
                        -    def __init__(
                        -        self, config: ExperimentConfig,
                        -        model: Model,
                        -        reference_model: ReferenceModel
                        -    ):
                        -        super().__init__(config, model, reference_model)
                        -
                        -    def _attack(self, document, probs, tokens=None, **kwargs):
                        -        """
                        -        Reference-based attack score. Performs difficulty calibration in model likelihood using a reference model.
                        -        """
                        -        loss = kwargs.get('loss', None)
                        -        if loss is None:
                        -            loss = self.target_model.get_ll(document, probs=probs, tokens=tokens)
                        -        ref_loss = self.ref_model.get_ll(document, probs=probs, tokens=tokens)
                        -        return loss - ref_loss
                        -
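The score returned above is a simple difficulty calibration: the target model's get_ll value (treated as a loss in the code) minus the same quantity from a reference model, so text that every model finds easy or hard is discounted. A minimal illustration with assumed, already-loaded models and an illustrative document:

    doc = "some candidate passage"
    score = target_model.get_ll(doc) - reference_model.get_ll(doc)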
                        @@ -116,7 +88,6 @@

                        Inherited members

                        MIMIR -

                        Index

                          @@ -137,7 +108,7 @@

                          -

                          Generated by pdoc 0.10.0.

                          +

                          Generated by pdoc 0.11.1.

-
\ No newline at end of file
+
diff --git a/docs/attacks/utils.html b/docs/attacks/utils.html
index 264ee55..3366e73 100644
--- a/docs/attacks/utils.html
+++ b/docs/attacks/utils.html
@@ -2,18 +2,21 @@
mimir.attacks.utils API documentation
                          @@ -22,37 +25,6 @@

                          Module mimir.attacks.utils

                          -
                          - -Expand source code - -
                          from mimir.attacks.all_attacks import AllAttacks
                          -
                          -from mimir.attacks.loss import LOSSAttack
                          -from mimir.attacks.reference import ReferenceAttack
                          -from mimir.attacks.zlib import ZLIBAttack
                          -from mimir.attacks.min_k import MinKProbAttack
                          -from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
                          -from mimir.attacks.neighborhood import NeighborhoodAttack
                          -from mimir.attacks.gradnorm import GradNormAttack
                          -
                          -
                          -# TODO Use decorators to link attack implementations with enum above
                          -def get_attacker(attack: str):
                          -    mapping = {
                          -        AllAttacks.LOSS: LOSSAttack,
                          -        AllAttacks.REFERENCE_BASED: ReferenceAttack,
                          -        AllAttacks.ZLIB: ZLIBAttack,
                          -        AllAttacks.MIN_K: MinKProbAttack,
                          -        AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
                          -        AllAttacks.NEIGHBOR: NeighborhoodAttack,
                          -        AllAttacks.GRADNORM: GradNormAttack,
                          -    }
                          -    attack_cls = mapping.get(attack, None)
                          -    if attack_cls is None:
                          -        raise ValueError(f"Attack {attack} not found")
                          -    return attack_cls
                          -
                          @@ -66,25 +38,6 @@

                          Functions

                          -
                          - -Expand source code - -
                          def get_attacker(attack: str):
                          -    mapping = {
                          -        AllAttacks.LOSS: LOSSAttack,
                          -        AllAttacks.REFERENCE_BASED: ReferenceAttack,
                          -        AllAttacks.ZLIB: ZLIBAttack,
                          -        AllAttacks.MIN_K: MinKProbAttack,
                          -        AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
                          -        AllAttacks.NEIGHBOR: NeighborhoodAttack,
                          -        AllAttacks.GRADNORM: GradNormAttack,
                          -    }
                          -    attack_cls = mapping.get(attack, None)
                          -    if attack_cls is None:
                          -        raise ValueError(f"Attack {attack} not found")
                          -    return attack_cls
                          -
                          @@ -97,7 +50,6 @@

                          Functions

                          MIMIR -

                          Index

                            @@ -116,7 +68,7 @@

                            Index

-
\ No newline at end of file
+
diff --git a/docs/attacks/zlib.html b/docs/attacks/zlib.html
index fda63ec..50b24da 100644
--- a/docs/attacks/zlib.html
+++ b/docs/attacks/zlib.html
@@ -2,18 +2,21 @@
mimir.attacks.zlib API documentation
                            @@ -23,46 +26,6 @@

                            Module mimir.attacks.zlib

                            zlib-normalization Attack: https://www.usenix.org/system/files/sec21-carlini-extracting.pdf

                            -
                            - -Expand source code - -
                            """
                            -    zlib-normalization Attack: https://www.usenix.org/system/files/sec21-carlini-extracting.pdf
                            -"""
                            -
                            -import torch as ch
                            -import zlib
                            -
                            -from mimir.attacks.all_attacks import Attack
                            -from mimir.models import Model
                            -from mimir.config import ExperimentConfig
                            -
                            -
                            -class ZLIBAttack(Attack):
                            -
                            -    def __init__(self,
                            -                 config: ExperimentConfig,
                            -                 model: Model):
                            -        super().__init__(config, model, ref_model=None)
                            -
                            -    @ch.no_grad()
                            -    def _attack(
                            -        self,
                            -        document,
                            -        probs,
                            -        tokens=None,
                            -        **kwargs
                            -    ):
                            -        """
                            -        zlib-based attack score. Performs difficulty calibration in model likelihood by normalizing with zlib entropy.
                            -        """
                            -        loss = kwargs.get("loss", None)
                            -        if loss is None:
                            -            loss = self.target_model.get_ll(document, probs=probs, tokens=tokens)
                            -        zlib_entropy = len(zlib.compress(bytes(document, "utf-8")))
                            -        return loss / zlib_entropy
                            -
                            @@ -130,7 +93,6 @@

                            Inherited members

                            MIMIR -

                            Index

                              @@ -151,7 +113,7 @@

                              -

                              Generated by pdoc 0.10.0.

                              +

                              Generated by pdoc 0.11.1.

-
\ No newline at end of file
+
diff --git a/docs/config.html b/docs/config.html
index 9c1fb4f..e2cb034 100644
--- a/docs/config.html
+++ b/docs/config.html
@@ -2,18 +2,21 @@
mimir.config API documentation
                              @@ -23,229 +26,6 @@

                              Module mimir.config

                              Definitions for configurations.

                              -
                              - -Expand source code - -
                              """
                              -    Definitions for configurations.
                              -"""
                              -
                              -from dataclasses import dataclass
                              -from typing import Optional, List
                              -from simple_parsing.helpers import Serializable, field
                              -from mimir.utils import get_cache_path, get_data_source
                              -
                              -
                              -@dataclass
                              -class ReferenceConfig(Serializable):
                              -    """
                              -    Config for attacks that use reference models.
                              -    """
                              -    models: List[str]
                              -    """Reference model names"""
                              -
                              -
                              -@dataclass
                              -class NeighborhoodConfig(Serializable):
                              -    """
                              -    Config for neighborhood attack
                              -    """
                              -    model: str
                              -    """Mask-filling model"""
                              -    n_perturbation_list: List[int] = field(default_factory=lambda: [1, 10])
                              -    """List of n_neighbors to try."""
                              -    dump_cache: Optional[bool] = False
                              -    "Dump neighbors data to cache? Exits program after dumping"
                              -    load_from_cache: Optional[bool] = False
                              -    """Load neighbors data from cache?"""
                              -    # BERT-specific param
                              -    original_tokenization_swap: Optional[bool] = True
                              -    """Swap out token in original text with neighbor token, instead of re-generating text"""
                              -    pct_swap_bert: Optional[float] = 0.05
                              -    """Percentage of tokens per neighbor that are different from the original text"""
                              -    neighbor_strategy: Optional[str] = "deterministic"
                              -    """Strategy for generating neighbors. One of ['deterministic', 'random']. Deterministic uses only one-word neighbors"""
                              -    # T-5 specific hyper-parameters
                              -    span_length: Optional[int] = 2
                              -    """Span length for neighborhood attack"""
                              -    random_fills_tokens: Optional[bool] = False
                              -    """Randomly fill tokens?"""
                              -    random_fills: Optional[bool] = False
                              -    """Randomly fill?"""
                              -    pct_words_masked: Optional[float] = 0.3
                              -    """Percentage masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size))"""
                              -    buffer_size: Optional[int] = 1
                              -    """Buffer size"""
                              -    top_p: Optional[float] = 1.0
                              -    """Use tokens (minimal set) with cumulative probability of <=top_p"""
                              -    max_tries: Optional[int] = 100
                              -    """Maximum number of trials in finding replacements for masked tokens"""
                              -    ceil_pct: Optional[bool] = False
                              -    """Apply ceil operation on span length calculation?"""
                              -
                              -    def __post_init__(self):
                              -        if self.dump_cache and self.load_from_cache:
                              -            raise ValueError("Cannot dump and load cache at the same time")
                              -
                              -
                              -@dataclass
                              -class EnvironmentConfig(Serializable):
                              -    """
                              -    Config for environment-specific parameters
                              -    """
                              -    cache_dir: Optional[str] = None
                              -    """Path to cache directory"""
                              -    data_source: Optional[str] = None
                              -    """Path where data is stored"""
                              -    device: Optional[str] = 'cuda:0'
                              -    """Device (GPU) to load main model on"""
                              -    device_map: Optional[str] = None
                              -    """Configuration for device map if needing to split model across gpus"""
                              -    device_aux: Optional[str] = "cuda:1"
                              -    """Device (GPU) to load any auxiliary model(s) on"""
                              -    compile: Optional[bool] = True
                              -    """Compile models?"""
                              -    int8: Optional[bool] = False
                              -    """Use int8 quantization?"""
                              -    half: Optional[bool] = False
                              -    """Use half precision?"""
                              -    results: Optional[str] = "results"
                              -    """Path for saving final results"""
                              -    tmp_results: Optional[str] = "tmp_results"
                              -
                              -    def __post_init__(self):
                              -        if self.cache_dir is None:
                              -            self.cache_dir = get_cache_path()
                              -        if self.data_source is None:
                              -            self.data_source = get_data_source()
                              -
                              -
                              -@dataclass
                              -class OpenAIConfig(Serializable):
                              -    """
                              -    Config for OpenAI calls
                              -    """
                              -    key: str
                              -    """OpenAI API key"""
                              -    model: str
                              -    """Model name"""
                              -
                              -
                              -@dataclass
                              -class ExperimentConfig(Serializable):
                              -    """
                              -    Config for attacks
                              -    """
                              -    experiment_name: str
                              -    """Name for the experiment"""
                              -    base_model: str
                              -    """Base model name"""
                              -    dataset_member: str
                              -    """Dataset source for members"""
                              -    dataset_nonmember: str
                              -    """Dataset source for nonmembers"""
                              -    output_name: str = None
                              -    """Output name for sub-directory."""
                              -    dataset_nonmember_other_sources: Optional[List[str]] = field(
                              -        default_factory=lambda: None
                              -    )
                              -    """Dataset sources for nonmembers for which metrics will be computed, using the thresholds derived from the main member/nonmember datasets"""
                              -    pretokenized: Optional[bool] = False
                              -    """Is the data already pretokenized"""
                              -    revision: Optional[str] = None
                              -    """Model revision to use"""
                              -    presampled_dataset_member: Optional[str] = None
                              -    """Path to presampled dataset source for members"""
                              -    presampled_dataset_nonmember: Optional[str] = None
                              -    """Path to presampled dataset source for non-members"""
                              -    token_frequency_map: Optional[
                              -        str
                              -    ] = None  # TODO: Handling auxiliary data structures
                              -    """Path to a pre-computed token frequency map"""
                              -    dataset_key: Optional[str] = None
                              -    """Dataset key"""
                              -    specific_source: Optional[str] = None
                              -    """Specific sub-source to focus on. Only valid for the_pile"""
                              -    full_doc: Optional[bool] = False  # TODO: refactor full_doc design?
                              -    """Determines whether MIA will be performed over entire doc or not"""
                              -    max_substrs: Optional[int] = 20
                              -    """If full_doc, determines the maximum number of sample substrs to evaluate on"""
                              -    dump_cache: Optional[bool] = False
                              -    """Dump data to cache? Exits program after dumping"""
                              -    load_from_cache: Optional[bool] = False
                              -    """Load data from cache?"""
                              -    load_from_hf: Optional[bool] = True
                              -    """Load data from HuggingFace?"""
                              -    blackbox_attacks: Optional[List[str]] = field(
                              -        default_factory=lambda: None
                              -    )  # Can replace with "default" attacks if we want
                              -    """List of attacks to evaluate"""
                              -    tokenization_attack: Optional[bool] = False
                              -    """Run tokenization attack?"""
                              -    quantile_attack: Optional[bool] = False
                              -    """Run quantile attack?"""
                              -    n_samples: Optional[int] = 200
                              -    """Number of records (member and non-member each) to run the attack(s) for"""
                              -    max_tokens: Optional[int] = 512
                              -    """Consider samples with at most these many tokens"""
                              -    max_data: Optional[int] = 5_000
                              -    """Maximum samples to load from data before processing. Helps with efficiency"""
                              -    min_words: Optional[int] = 100
                              -    """Consider documents with at least these many words"""
                              -    max_words: Optional[int] = 200
                              -    """Consider documents with at most these many words"""
                              -    max_words_cutoff: Optional[bool] = True
                              -    """Is max_words a selection criteria (False), or a cutoff added on text (True)?"""
                              -    batch_size: Optional[int] = 50
                              -    """Batch size"""
                              -    chunk_size: Optional[int] = 20
                              -    """Chunk size"""
                              -    scoring_model_name: Optional[str] = None
                              -    """Scoring model (if different from base model)"""
                              -    top_k: Optional[int] = 40
                              -    """Consider only top-k tokens"""
                              -    do_top_k: Optional[bool] = False
                              -    """Use top-k sampling?"""
                              -    top_p: Optional[float] = 0.96
                              -    """Use tokens (minimal set) with cumulative probability of <=top_p"""
                              -    do_top_p: Optional[bool] = False
                              -    """Use top-p sampling?"""
                              -    pre_perturb_pct: Optional[float] = 0.0
                              -    """Percentage of tokens to perturb before attack"""
                              -    pre_perturb_span_length: Optional[int] = 5
                              -    """Span length for pre-perturbation"""
                              -    tok_by_tok: Optional[bool] = False
                              -    """Process data token-wise?"""
                              -    fpr_list: Optional[List[float]] = field(default_factory=lambda: [0.001, 0.01])
                              -    """FPRs at which to compute TPR"""
                              -    random_seed: Optional[int] = 0
                              -    """Random seed"""
                              -    ref_config: Optional[ReferenceConfig] = None
                              -    """Reference model config"""
                              -    neighborhood_config: Optional[NeighborhoodConfig] = None
                              -    """Neighborhood attack config"""
                              -    env_config: Optional[EnvironmentConfig] = None
                              -    """Environment config"""
                              -    openai_config: Optional[OpenAIConfig] = None
                              -    """OpenAI config"""
                              -
                              -    def __post_init__(self):
                              -        if self.dump_cache and (self.load_from_cache or self.load_from_hf):
                              -            raise ValueError("Cannot dump and load cache at the same time")
                              -
                              -        if self.neighborhood_config:
                              -            if (
                              -                self.neighborhood_config.dump_cache
                              -                or self.neighborhood_config.load_from_cache
                              -            ) and not (self.load_from_cache or self.dump_cache or self.load_from_hf):
                              -                raise ValueError(
                              -                    "Using dump/load for neighborhood cache without dumping/loading main cache does not make sense"
                              -                )
                              -
                              -            if self.neighborhood_config.dump_cache and (self.neighborhood_config.load_from_cache or self.load_from_hf):
                              -                raise ValueError("Cannot dump and load neighborhood cache at the same time")    
                              -
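As a usage sketch (not part of this patch), the dataclasses above can be instantiated directly with the fields shown; the model and dataset names below are placeholders, and the example assumes the cache/data environment used by get_cache_path and get_data_source is configured:

from mimir.config import (
    ExperimentConfig,
    ReferenceConfig,
    NeighborhoodConfig,
    EnvironmentConfig,
)

config = ExperimentConfig(
    experiment_name="demo_run",                  # placeholder name
    base_model="EleutherAI/pythia-160m",         # placeholder target model
    dataset_member="the_pile",
    dataset_nonmember="the_pile",
    n_samples=200,
    ref_config=ReferenceConfig(models=["EleutherAI/pythia-70m"]),
    neighborhood_config=NeighborhoodConfig(model="bert-base-uncased"),
    env_config=EnvironmentConfig(device="cuda:0"),
)

# Worked example of the documented masking rate for NeighborhoodConfig:
# pct_words_masked * span_length / (span_length + 2 * buffer_size)
nc = config.neighborhood_config
print(nc.pct_words_masked * nc.span_length / (nc.span_length + 2 * nc.buffer_size))
# With the defaults (0.3, 2, 1): 0.3 * 2 / 4 = 0.15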
@@ -869,7 +649,6 @@

Class variables

@@ -988,7 +767,7 @@

-Generated by pdoc 0.10.0.
+Generated by pdoc 0.11.1.

\ No newline at end of file
diff --git a/docs/custom_datasets.html b/docs/custom_datasets.html
index e82f863..39ac586 100644
--- a/docs/custom_datasets.html
+++ b/docs/custom_datasets.html
@@ -2,18 +2,21 @@

mimir.custom_datasets API documentation
                                @@ -23,199 +26,6 @@

                                Module mimir.custom_datasets

                                Helper functions for processing of data (ultimately used for membership inference evaluation)

                                """
                                -    Helper functions for processing of data (ultimately used for membership inference evaluation)
                                -"""
                                -import random
                                -import datasets
                                -import os
                                -import json
                                -from typing import List
                                -
                                -
                                -SEPARATOR = '<<<SEP>>>'
                                -
                                -DATASETS = ['writing', 'english', 'german', 'pubmed']
                                -
                                -SOURCES_UPLOADED = [
                                -    "arxiv",
                                -    "dm_mathematics",
                                -    "github",
                                -    "hackernews",
                                -    "pile_cc",
                                -    "pubmed_central",
                                -    "wikipedia_(en)",
                                -    "full_pile",
                                -    "c4",
                                -    "temporal_arxiv",
                                -    "temporal_wiki"
                                -]
                                -
                                -
                                -def load_pubmed(cache_dir):
                                -    data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir)
                                -    
                                -    # combine question and long_answer
                                -    data = [f'Question: {q} Answer:{SEPARATOR}{a}' for q, a in zip(data['question'], data['long_answer'])]
                                -
                                -    return data
                                -
                                -
                                -def load_cached(cache_dir,
                                -                data_split: str,
                                -                filename: str,
                                -                min_length: int,
                                -                max_length: int,
                                -                n_samples: int,
                                -                max_tokens: int,
                                -                load_from_hf: bool = False):
-    """
-        Read from cache if available. Used for certain pile sources and xsum
-        to ensure fairness in comparison across attack runs.
                                -    """
                                -    if load_from_hf:
                                -        print("Loading from HuggingFace!")
                                -        data_split = data_split.replace("train", "member")
                                -        data_split = data_split.replace("test", "nonmember")
                                -        if not filename.startswith("the_pile"):
-            raise ValueError("HuggingFace data only available for The Pile.")
                                -
                                -        for source in SOURCES_UPLOADED:
                                -            # Got a match
                                -            if source in filename and filename.startswith(f"the_pile_{source}"):
                                -                split = filename.split(f"the_pile_{source}")[1]
                                -                if split == "":
                                -                    # The way HF data is uploaded, no split is recorded as "none"
                                -                    split = "none"
                                -                else:
                                -                    # remove the first underscore
                                -                    split = split[1:]
                                -                    # remove '<' , '>'
                                -                    split = split.replace("<", "").replace(">", "")
                                -                    # Remove "_truncated" from the end, if present
                                -                    split = split.rsplit("_truncated", 1)[0]
                                -
                                -                # Load corresponding dataset
                                -                ds = datasets.load_dataset("iamgroot42/mimir", name=source, split=split)
                                -                data = ds[data_split]
                                -                # Check if the number of samples is correct
                                -                if len(data) != n_samples:
                                -                    raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
                                -                return data
                                -        # If got here, matching source was not found
                                -        raise ValueError(f"Requested source {filename} not found in HuggingFace data.")
                                -    else:
                                -        file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
                                -        if not os.path.exists(file_path):
                                -            raise ValueError(f"Requested cache file {file_path} does not exist")
                                -        data = load_data(file_path)
                                -    return data
                                -
                                -
                                -def load_data(file_path):
                                -    """
                                -        Load data from a given filepath (.jsonl)
                                -    """
                                -    with open(file_path, 'r') as f:
                                -        data = [json.loads(line) for line in f.readlines()]
                                -    return data
                                -
                                -
                                -def dump_to_cache(data: List, cache_dir, path, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int):
                                -    """
                                -        Cache a file (one sample per line)
                                -    """
                                -    # Make sure path directory exists
                                -    subdir = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path)
                                -    os.makedirs(subdir, exist_ok=True)
                                -    # Dump to file
                                -    # Since each datum has newlines in it potentially, use jsonl format
                                -    save_data(os.path.join(subdir, filename + ".jsonl"), data)
                                -
                                -
                                -def save_data(file_path, data):
                                -    # Since each datum has newlines in it potentially, use jsonl format
                                -    with open(file_path, 'w') as f:
                                -        for datum in data:
                                -            f.write(json.dumps(datum) + "\n")
                                -
                                -
                                -def process_prompt(prompt):
                                -    return prompt.replace('[ WP ]', '').replace('[ OT ]', '')
                                -
                                -
                                -def process_spaces(story):
                                -    return story.replace(
                                -        ' ,', ',').replace(
                                -        ' .', '.').replace(
                                -        ' ?', '?').replace(
                                -        ' !', '!').replace(
                                -        ' ;', ';').replace(
                                -        ' \'', '\'').replace(
                                -        ' ’ ', '\'').replace(
                                -        ' :', ':').replace(
                                -        '<newline>', '\n').replace(
                                -        '`` ', '"').replace(
                                -        ' \'\'', '"').replace(
                                -        '\'\'', '"').replace(
                                -        '.. ', '... ').replace(
                                -        ' )', ')').replace(
                                -        '( ', '(').replace(
                                -        ' n\'t', 'n\'t').replace(
                                -        ' i ', ' I ').replace(
                                -        ' i\'', ' I\'').replace(
                                -        '\\\'', '\'').replace(
                                -        '\n ', '\n').strip()
                                -
                                -
                                -def load_writing(cache_dir=None):
                                -    writing_path = 'data/writingPrompts'
                                -    
                                -    with open(f'{writing_path}/valid.wp_source', 'r') as f:
                                -        prompts = f.readlines()
                                -    with open(f'{writing_path}/valid.wp_target', 'r') as f:
                                -        stories = f.readlines()
                                -    
                                -    prompts = [process_prompt(prompt) for prompt in prompts]
                                -    joined = [process_spaces(prompt + " " + story) for prompt, story in zip(prompts, stories)]
                                -    filtered = [story for story in joined if 'nsfw' not in story and 'NSFW' not in story]
                                -
                                -    random.seed(0)
                                -    random.shuffle(filtered)
                                -
                                -    return filtered
                                -
                                -
                                -def load_language(language, cache_dir):
                                -    # load either the english or german portion of the wmt16 dataset
                                -    assert language in ['en', 'de']
                                -    d = datasets.load_dataset('wmt16', 'de-en', split='train', cache_dir=cache_dir)
                                -    docs = d['translation']
                                -    desired_language_docs = [d[language] for d in docs]
                                -    lens = [len(d.split()) for d in desired_language_docs]
                                -    sub = [d for d, l in zip(desired_language_docs, lens) if l > 100 and l < 150]
                                -    return sub
                                -
                                -
                                -def load_german(cache_dir):
                                -    return load_language('de', cache_dir)
                                -
                                -
                                -def load_english(cache_dir):
                                -    return load_language('en', cache_dir)
                                -
                                -
                                -def load(name, cache_dir, **kwargs):
                                -    if name in DATASETS:
                                -        load_fn = globals()[f'load_{name}']
                                -        return load_fn(cache_dir=cache_dir, **kwargs)
                                -    else:
                                -        raise ValueError(f'Unknown dataset {name}')
                                -
                                @@ -229,38 +39,12 @@

                                Functions

                                Cache a file (one sample per line)

                                def dump_to_cache(data: List, cache_dir, path, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int):
                                -    """
                                -        Cache a file (one sample per line)
                                -    """
                                -    # Make sure path directory exists
                                -    subdir = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path)
                                -    os.makedirs(subdir, exist_ok=True)
                                -    # Dump to file
                                -    # Since each datum has newlines in it potentially, use jsonl format
                                -    save_data(os.path.join(subdir, filename + ".jsonl"), data)
                                -
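To make the cache layout concrete, here is a small sketch of the path that dump_to_cache writes and load_cached later reads, following the f-string in the source; the cache root and filename below are illustrative:

import os

cache_dir = "/tmp/mimir_cache"                     # illustrative cache root
min_length, max_length, n_samples, max_tokens = 100, 200, 200, 512
path, filename = "train", "the_pile_wikipedia_(en)"

subdir = os.path.join(
    cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path
)
print(os.path.join(subdir, filename + ".jsonl"))
# /tmp/mimir_cache/cache_100_200_200_512/train/the_pile_wikipedia_(en).jsonl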
                                def load(name, cache_dir, **kwargs)
                                def load(name, cache_dir, **kwargs):
                                -    if name in DATASETS:
                                -        load_fn = globals()[f'load_{name}']
                                -        return load_fn(cache_dir=cache_dir, **kwargs)
                                -    else:
                                -        raise ValueError(f'Unknown dataset {name}')
                                -
def load_cached(cache_dir, data_split: str, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int, load_from_hf: bool = False)
@@ -269,231 +53,60 @@

Functions

Read from cache if available. Used for certain pile sources and xsum to ensure fairness in comparison across attack runs.

                                def load_cached(cache_dir,
                                -                data_split: str,
                                -                filename: str,
                                -                min_length: int,
                                -                max_length: int,
                                -                n_samples: int,
                                -                max_tokens: int,
                                -                load_from_hf: bool = False):
-    """
-        Read from cache if available. Used for certain pile sources and xsum
-        to ensure fairness in comparison across attack runs.
                                -    """
                                -    if load_from_hf:
                                -        print("Loading from HuggingFace!")
                                -        data_split = data_split.replace("train", "member")
                                -        data_split = data_split.replace("test", "nonmember")
                                -        if not filename.startswith("the_pile"):
-            raise ValueError("HuggingFace data only available for The Pile.")
                                -
                                -        for source in SOURCES_UPLOADED:
                                -            # Got a match
                                -            if source in filename and filename.startswith(f"the_pile_{source}"):
                                -                split = filename.split(f"the_pile_{source}")[1]
                                -                if split == "":
                                -                    # The way HF data is uploaded, no split is recorded as "none"
                                -                    split = "none"
                                -                else:
                                -                    # remove the first underscore
                                -                    split = split[1:]
                                -                    # remove '<' , '>'
                                -                    split = split.replace("<", "").replace(">", "")
                                -                    # Remove "_truncated" from the end, if present
                                -                    split = split.rsplit("_truncated", 1)[0]
                                -
                                -                # Load corresponding dataset
                                -                ds = datasets.load_dataset("iamgroot42/mimir", name=source, split=split)
                                -                data = ds[data_split]
                                -                # Check if the number of samples is correct
                                -                if len(data) != n_samples:
                                -                    raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
                                -                return data
                                -        # If got here, matching source was not found
                                -        raise ValueError(f"Requested source {filename} not found in HuggingFace data.")
                                -    else:
                                -        file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
                                -        if not os.path.exists(file_path):
                                -            raise ValueError(f"Requested cache file {file_path} does not exist")
                                -        data = load_data(file_path)
                                -    return data
                                -
                                def load_data(file_path)

                                Load data from a given filepath (.jsonl)

                                def load_data(file_path):
                                -    """
                                -        Load data from a given filepath (.jsonl)
                                -    """
                                -    with open(file_path, 'r') as f:
                                -        data = [json.loads(line) for line in f.readlines()]
                                -    return data
                                -
                                def load_english(cache_dir)
                                def load_english(cache_dir):
                                -    return load_language('en', cache_dir)
                                -
                                def load_german(cache_dir)
                                def load_german(cache_dir):
                                -    return load_language('de', cache_dir)
                                -
                                def load_language(language, cache_dir)
                                def load_language(language, cache_dir):
                                -    # load either the english or german portion of the wmt16 dataset
                                -    assert language in ['en', 'de']
                                -    d = datasets.load_dataset('wmt16', 'de-en', split='train', cache_dir=cache_dir)
                                -    docs = d['translation']
                                -    desired_language_docs = [d[language] for d in docs]
                                -    lens = [len(d.split()) for d in desired_language_docs]
                                -    sub = [d for d, l in zip(desired_language_docs, lens) if l > 100 and l < 150]
                                -    return sub
                                -
                                def load_pubmed(cache_dir)
                                def load_pubmed(cache_dir):
                                -    data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir)
                                -    
                                -    # combine question and long_answer
                                -    data = [f'Question: {q} Answer:{SEPARATOR}{a}' for q, a in zip(data['question'], data['long_answer'])]
                                -
                                -    return data
                                -
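Each record returned by load_pubmed is a single string with the question and the long answer joined around the SEPARATOR token; for example (the question/answer text is made up):

SEPARATOR = '<<<SEP>>>'
q, a = "Is treatment X effective?", "Yes, in the studied cohort."
print(f'Question: {q} Answer:{SEPARATOR}{a}')
# Question: Is treatment X effective? Answer:<<<SEP>>>Yes, in the studied cohort.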
                                def load_writing(cache_dir=None)
                                def load_writing(cache_dir=None):
                                -    writing_path = 'data/writingPrompts'
                                -    
                                -    with open(f'{writing_path}/valid.wp_source', 'r') as f:
                                -        prompts = f.readlines()
                                -    with open(f'{writing_path}/valid.wp_target', 'r') as f:
                                -        stories = f.readlines()
                                -    
                                -    prompts = [process_prompt(prompt) for prompt in prompts]
                                -    joined = [process_spaces(prompt + " " + story) for prompt, story in zip(prompts, stories)]
                                -    filtered = [story for story in joined if 'nsfw' not in story and 'NSFW' not in story]
                                -
                                -    random.seed(0)
                                -    random.shuffle(filtered)
                                -
                                -    return filtered
                                -
                                def process_prompt(prompt)
                                def process_prompt(prompt):
                                -    return prompt.replace('[ WP ]', '').replace('[ OT ]', '')
                                -
                                def process_spaces(story)
                                def process_spaces(story):
                                -    return story.replace(
                                -        ' ,', ',').replace(
                                -        ' .', '.').replace(
                                -        ' ?', '?').replace(
                                -        ' !', '!').replace(
                                -        ' ;', ';').replace(
                                -        ' \'', '\'').replace(
                                -        ' ’ ', '\'').replace(
                                -        ' :', ':').replace(
                                -        '<newline>', '\n').replace(
                                -        '`` ', '"').replace(
                                -        ' \'\'', '"').replace(
                                -        '\'\'', '"').replace(
                                -        '.. ', '... ').replace(
                                -        ' )', ')').replace(
                                -        '( ', '(').replace(
                                -        ' n\'t', 'n\'t').replace(
                                -        ' i ', ' I ').replace(
                                -        ' i\'', ' I\'').replace(
                                -        '\\\'', '\'').replace(
                                -        '\n ', '\n').strip()
                                -
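A quick before/after of the detokenization clean-up that process_spaces performs (the input string is invented for illustration):

from mimir.custom_datasets import process_spaces

raw = "`` Hello , world ! '' he said .<newline>It was n't late ."
print(process_spaces(raw))
# "Hello, world!" he said.
# It wasn't late.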
                                def save_data(file_path, data)
                                def save_data(file_path, data):
                                -    # Since each datum has newlines in it potentially, use jsonl format
                                -    with open(file_path, 'w') as f:
                                -        for datum in data:
                                -            f.write(json.dumps(datum) + "\n")
                                -
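A minimal round trip through save_data and load_data (the file path and records are hypothetical); because each record is serialized as one JSON object per line, embedded newlines inside a sample are preserved safely:

from mimir.custom_datasets import save_data, load_data

records = ["first sample\nwith an embedded newline", "second sample"]
save_data("/tmp/demo.jsonl", records)   # hypothetical path
assert load_data("/tmp/demo.jsonl") == records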
@@ -506,7 +119,6 @@

Functions

@@ -536,7 +148,7 @@

\ No newline at end of file
diff --git a/docs/data_utils.html b/docs/data_utils.html
index 16ddbc6..27b719f 100644
--- a/docs/data_utils.html
+++ b/docs/data_utils.html
@@ -2,18 +2,21 @@

mimir.data_utils API documentation
                                  @@ -23,342 +26,6 @@

                                  Module mimir.data_utils

                                  Datasets and data-processing utilities

                                  """
                                  -    Datasets and data-processing utilities
                                  -"""
                                  -import datasets
                                  -import numpy as np
                                  -import os
                                  -import mimir.custom_datasets as custom_datasets
                                  -from mimir.config import ExperimentConfig
                                  -from nltk.tokenize import WhitespaceTokenizer
                                  -
                                  -
                                  -class Data:
                                  -    """
                                  -    Data class to load and cache datasets.
                                  -    """
                                  -    def __init__(self, name,
                                  -                 config: ExperimentConfig,
                                  -                 presampled: str = None,
                                  -                 name_key_mapping: dict = {"the_pile": "text", "xsum": "document"}):
                                  -        self.name_key_mapping = name_key_mapping
                                  -        self.config = config
                                  -        self.name = name
                                  -        self.presampled = presampled
                                  -        self.key = (
                                  -            config.dataset_key
                                  -            if config.dataset_key
                                  -            else self.name_key_mapping.get(name, None)
                                  -        )
                                  -        if self.key is None:
                                  -            raise ValueError(
-                f"Key for dataset {name} not provided or found in name_key_mapping"
                                  -            )
                                  -        self.cache_dir = self.config.env_config.cache_dir
                                  -
                                  -    def load_neighbors(
                                  -        self,
                                  -        train: bool,
                                  -        num_neighbors: int,
                                  -        model: str = "bert",
                                  -        in_place_swap: bool = False,
                                  -    ):
                                  -        """
                                  -        Load neighbors from cache (local or from HF)
                                  -        """
                                  -        data_split = "train" if train else "test"
                                  -        data_split += "_neighbors"
                                  -        filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
                                  -            num_neighbors, model
                                  -        )
                                  -        if in_place_swap:
                                  -            filename += "_in_place_swap"
                                  -        data = custom_datasets.load_cached(
                                  -            self.cache_dir,
                                  -            data_split,
                                  -            filename,
                                  -            min_length=self.config.min_words,
                                  -            max_length=self.config.max_words,
                                  -            n_samples=self.config.n_samples,
                                  -            max_tokens=self.config.max_tokens,
                                  -            load_from_hf=self.config.load_from_hf
                                  -        )
                                  -        return data
                                  -
                                  -    def dump_neighbors(
                                  -        self,
                                  -        data,
                                  -        train: bool,
                                  -        num_neighbors: int,
                                  -        model: str = "bert",
                                  -        in_place_swap: bool = False,
                                  -    ):
                                  -        """
-        Dump neighbors to local cache.
                                  -        """
                                  -        data_split = "train" if train else "test"
                                  -        data_split += "_neighbors"
                                  -        filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
                                  -            num_neighbors, model
                                  -        )
                                  -        if in_place_swap:
                                  -            filename += "_in_place_swap"
                                  -        custom_datasets.dump_to_cache(
                                  -            data,
                                  -            self.cache_dir,
                                  -            data_split,
                                  -            filename,
                                  -            min_length=self.config.min_words,
                                  -            max_length=self.config.max_words,
                                  -            n_samples=self.config.n_samples,
                                  -            max_tokens=self.config.max_tokens,
                                  -        )
                                  -
                                  -    def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
                                  -        data_split = "train" if train else "test"
                                  -        n_samples = self.config.n_samples
                                  -
                                  -        # Load from numpy file storing pretokenized sample in a 2d array of shape (num_samples, num_tokens_per_sample)
                                  -        if self.config.pretokenized:
                                  -            assert self.presampled
                                  -            # TODO: Pretokenized full documents (split into substrs) is not currently supported
                                  -            assert not self.config.full_doc
                                  -            data = np.load(self.presampled)
                                  -            return data
                                  -        elif (self.config.load_from_cache or self.config.load_from_hf):
                                  -            # Load from cache, if requested
                                  -            filename = self._get_name_to_save()
                                  -            data = custom_datasets.load_cached(
                                  -                self.cache_dir,
                                  -                data_split,
                                  -                filename,
                                  -                min_length=self.config.min_words,
                                  -                max_length=self.config.max_words,
                                  -                n_samples=self.config.n_samples,
                                  -                max_tokens=self.config.max_tokens,
                                  -                load_from_hf=self.config.load_from_hf
                                  -            )
                                  -            return data
                                  -        else:
                                  -            if self.presampled or self.config.full_doc:
                                  -                print("using presampled data")
                                  -                data = datasets.load_dataset(
                                  -                    "json",
                                  -                    data_files=self.presampled,
                                  -                    split=f"train",
                                  -                    cache_dir=self.cache_dir,
                                  -                )[self.key]
                                  -            elif self.name in custom_datasets.DATASETS:
-                data = custom_datasets.load(self.name, cache_dir=self.cache_dir)
                                  -            elif self.name == "the_pile":
                                  -                min_load = max(10000, self.config.max_data)
                                  -                data = datasets.load_dataset(
                                  -                    "json",
                                  -                    data_files=os.path.join(
                                  -                        self.config.env_config.data_source,
                                  -                        "pile/00.jsonl.zst" if train else "pile/test.jsonl.zst",
                                  -                    ),
                                  -                    cache_dir=self.cache_dir,
                                  -                    split=f"train[:{min_load}]",
                                  -                )
                                  -                specific_source_use = (
                                  -                    self.config.specific_source
                                  -                    if specific_source is None
                                  -                    else specific_source
                                  -                )
                                  -                data = pile_selection_utility(
                                  -                    data, self.key, wanted_source=specific_source_use
                                  -                )
                                  -            elif "human" in self.name:
                                  -                data = datasets.load_dataset(
                                  -                    self.name, split=f"train[:100]", cache_dir=self.cache_dir
                                  -                )[self.key]
                                  -            elif "nthngdy" in self.name:
                                  -                data = datasets.load_dataset(
                                  -                    self.name, split="test", cache_dir=self.cache_dir
                                  -                )[self.key]
                                  -            else:
                                  -                data = datasets.load_dataset(
                                  -                    self.name, split=f"train", cache_dir=self.cache_dir
                                  -                )[self.key]
                                  -
                                  -        if not self.config.full_doc:
                                  -            # get unique examples
                                  -            # then take just the long examples, shuffle, take the first 5,000 to tokenize to save time
                                  -            # then take just the examples that are <= 512 tokens (for the mask model)
                                  -            # then generate n_samples samples
                                  -            wsp_tokenizer = WhitespaceTokenizer()
                                  -
                                  -            # remove duplicates from the data
                                  -            data = list(dict.fromkeys(data))  # deterministic, as opposed to set()
                                  -
                                  -            whitespace_tokenized_spans = [
                                  -                (x, list(wsp_tokenizer.span_tokenize(x))) for x in data
                                  -            ]
                                  -
                                  -            # Pick samples with at least self.config.min_words words
                                  -            whitespace_tokenized_spans = [
                                  -                x
                                  -                for x in whitespace_tokenized_spans
                                  -                if len(x[1]) >= self.config.min_words
                                  -            ]
                                  -            if len(whitespace_tokenized_spans) == 0:
                                  -                raise ValueError("No examples with length >= min_words")
                                  -
                                  -            if self.config.max_words_cutoff:
                                  -                last_spans = [
                                  -                    x[1][min(self.config.max_words, len(x[1])) - 1][1]
                                  -                    for x in whitespace_tokenized_spans
                                  -                ]
                                  -                data = [
                                  -                    x[0][:y] for x, y in zip(whitespace_tokenized_spans, last_spans)
                                  -                ]
                                  -            else:
                                  -                data = [
                                  -                    x[0]
                                  -                    for x in whitespace_tokenized_spans
                                  -                    if len(x[1]) < self.config.max_words
                                  -                ]
                                  -                if len(data) == 0:
                                  -                    raise ValueError("No examples with length < max_words")
                                  -
                                  -            # TODO: why shuffle
                                  -            # random.seed(0)
                                  -            # random.shuffle(data)
                                  -
                                  -            data = data[: self.config.max_data]
                                  -
                                  -            # If there is mask tokenizer, keep only examples with <= 512 tokens according to mask_tokenizer
                                  -            # this step has the extra effect of removing examples with low-quality/garbage content
                                  -            if mask_tokenizer:
                                  -                tokenized_data = mask_tokenizer(data)
                                  -                new_data = []
                                  -                for i, (x, y) in enumerate(zip(data, tokenized_data["input_ids"])):
                                  -                    if len(y) <= self.config.max_tokens:
                                  -                        new_data.append(x)
                                  -                    else:
                                  -                        print(
                                  -                            "Trimming text to nearest word that fits within mask tokenizer window"
                                  -                        )
                                  -                        max_token_char_span = tokenized_data.token_to_chars(
                                  -                            i, self.config.max_tokens - 1
                                  -                        )
                                  -                        x = x[: max_token_char_span.end]
                                  -                        token_truncated_word_spans = list(
                                  -                            wsp_tokenizer.span_tokenize(x)
                                  -                        )
                                  -
                                  -                        # Pop off the last "word" since it may be a word piece
                                  -                        second_last_span = token_truncated_word_spans[-2]
                                  -                        x = x[: second_last_span[1]]
                                  -
                                  -                        new_len = len(mask_tokenizer(x)["input_ids"])
                                  -                        assert new_len <= self.config.max_tokens
                                  -                        new_data.append(x)
                                  -                data = new_data
                                  -
-            # print stats about remaining data
                                  -            print(f"Total number of samples: {len(data)}")
                                  -            print(f"Average number of words: {np.mean([len(x.split()) for x in data])}")
                                  -
                                  -            if n_samples > len(data):
                                  -                print(f"WARNING: n_samples ({n_samples}) > len(data) ({len(data)})")
                                  -
                                  -        # Sample 'n_samples' examples
                                  -        data = data[:n_samples]
                                  -
                                  -        # Save to cache (if requested)
                                  -        if self.config.dump_cache:
                                  -            self.dump_to_cache(data, data_split)
                                  -
                                  -        return data
                                  -
                                  -    def dump_to_cache(self, data, data_split):
                                  -        filename = self._get_name_to_save()
                                  -        custom_datasets.dump_to_cache(
                                  -            data,
                                  -            self.cache_dir,
                                  -            data_split,
                                  -            filename,
                                  -            min_length=self.config.min_words,
                                  -            max_length=self.config.max_words,
                                  -            n_samples=self.config.n_samples,
                                  -            max_tokens=self.config.max_tokens,
                                  -        )
                                  -
                                  -    def _get_name_to_save(self):
                                  -        if self.config.specific_source and self.name == "the_pile":
                                  -            processed_source = sourcename_process(self.config.specific_source)
                                  -            filename = f"{self.name}_{processed_source}"
                                  -        else:
                                  -            filename = self.name
                                  -        return filename
                                  -
                                  -
                                  -def strip_newlines(text):
                                  -    """
                                  -    Strip newlines from each example; replace one or more newlines with a single space
                                  -    """
                                  -    return " ".join(text.split())
                                  -
                                  -
                                  -def trim_to_shorter_length(text_a: str, text_b: str, max_length: int = None):
                                  -    """
-    Truncate both texts to the word count of the shorter one (optionally capped at max_length)
                                  -    """
                                  -    shorter_length = min(len(text_a.split(" ")), len(text_b.split(" ")))
                                  -    if max_length is not None:
                                  -        shorter_length = min(shorter_length, max_length)
                                  -    text_a = " ".join(text_a.split(" ")[:shorter_length])
                                  -    text_b = " ".join(text_b.split(" ")[:shorter_length])
                                  -    return text_a, text_b
                                  -
                                  -
                                  -def truncate_to_substring(text: str, substring: str, idx_occurrence: int):
                                  -    """
                                  -    Truncate everything after the idx_occurrence occurrence of substring
                                  -    """
                                  -    assert idx_occurrence > 0, "idx_occurrence must be > 0"
                                  -    idx = -1
                                  -    for _ in range(idx_occurrence):
                                  -        idx = text.find(substring, idx + 1)
                                  -        if idx == -1:
                                  -            return text
                                  -    return text[:idx]
                                  -
                                  -
                                  -def pile_selection_utility(data, key: str, wanted_source: str = None):
                                  -    """
                                  -    Filter and select data corresponding to source, if requested.
                                  -    """
                                  -    if wanted_source is None:
                                  -        return data[key]
                                  -    wanted_data = []
                                  -    # Pick sources that match requested source
                                  -    for datum in data:
                                  -        if datum["meta"]["pile_set_name"] == wanted_source:
                                  -            wanted_data.append(datum[key])
                                  -    return wanted_data
                                  -
                                  -
                                  -def sourcename_process(x: str):
                                  -    """
                                  -        Helper function to process source name.
                                  -    """
                                  -    return x.replace(" ", "_").replace("-", "_").lower()
                                  -
                                  -
                                  -def drop_last_word(text):
                                  -    """
                                  -        Drop the last word from a given text.
                                  -    """
                                  -    return " ".join(text.split(" ")[:-1])
                                  -
                                  @@ -372,114 +39,36 @@

                                  Functions

                                  Drop the last word from a given text.

                                  -
                                  - -Expand source code - -
                                  def drop_last_word(text):
                                  -    """
                                  -        Drop the last word from a given text.
                                  -    """
                                  -    return " ".join(text.split(" ")[:-1])
                                  -
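For illustration, a minimal, self-contained check of the behaviour described above (words are split on single spaces and the trailing word is dropped):

from mimir.data_utils import drop_last_word

assert drop_last_word("membership inference attack") == "membership inference"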
                                  def pile_selection_utility(data, key: str, wanted_source: str = None)

                                  Filter and select data corresponding to source, if requested.

                                  -
                                  - -Expand source code - -
                                  def pile_selection_utility(data, key: str, wanted_source: str = None):
                                  -    """
                                  -    Filter and select data corresponding to source, if requested.
                                  -    """
                                  -    if wanted_source is None:
                                  -        return data[key]
                                  -    wanted_data = []
                                  -    # Pick sources that match requested source
                                  -    for datum in data:
                                  -        if datum["meta"]["pile_set_name"] == wanted_source:
                                  -            wanted_data.append(datum[key])
                                  -    return wanted_data
                                  -
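A minimal sketch of the filtering described above, assuming Pile-style records in which each entry carries a meta.pile_set_name field (the toy records below are illustrative, not real data):

from mimir.data_utils import pile_selection_utility

records = [
    {"text": "a github readme", "meta": {"pile_set_name": "Github"}},
    {"text": "a wiki article", "meta": {"pile_set_name": "Wikipedia (en)"}},
]
assert pile_selection_utility(records, key="text", wanted_source="Github") == ["a github readme"]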
                                  def sourcename_process(x: str)

                                  Helper function to process source name.

                                  -
                                  - -Expand source code - -
                                  def sourcename_process(x: str):
                                  -    """
                                  -        Helper function to process source name.
                                  -    """
                                  -    return x.replace(" ", "_").replace("-", "_").lower()
                                  -
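The normalisation replaces spaces and hyphens with underscores and lowercases the result, which is handy when building cache filenames; for example:

from mimir.data_utils import sourcename_process

assert sourcename_process("Wikipedia (en)") == "wikipedia_(en)"
assert sourcename_process("USPTO-Backgrounds") == "uspto_backgrounds"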
                                  def strip_newlines(text)

                                  Strip newlines from each example; replace one or more newlines with a single space

                                  -
                                  - -Expand source code - -
                                  def strip_newlines(text):
                                  -    """
                                  -    Strip newlines from each example; replace one or more newlines with a single space
                                  -    """
                                  -    return " ".join(text.split())
                                  -
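Note that, as implemented, this collapses any run of whitespace (not only newlines) into a single space; a quick check:

from mimir.data_utils import strip_newlines

assert strip_newlines("first line\n\nsecond line\n") == "first line second line"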
                                  def trim_to_shorter_length(text_a: str, text_b: str, max_length: int = None)

Truncate both texts to the word count of the shorter one (optionally capped at max_length)

                                  -
                                  - -Expand source code - -
                                  def trim_to_shorter_length(text_a: str, text_b: str, max_length: int = None):
                                  -    """
-    Truncate both texts to the word count of the shorter one (optionally capped at max_length)
                                  -    """
                                  -    shorter_length = min(len(text_a.split(" ")), len(text_b.split(" ")))
                                  -    if max_length is not None:
                                  -        shorter_length = min(shorter_length, max_length)
                                  -    text_a = " ".join(text_a.split(" ")[:shorter_length])
                                  -    text_b = " ".join(text_b.split(" ")[:shorter_length])
                                  -    return text_a, text_b
                                  -
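Both texts are cut to the word count of the shorter one, optionally capped by max_length; for example:

from mimir.data_utils import trim_to_shorter_length

assert trim_to_shorter_length("one two three four", "uno dos tres") == ("one two three", "uno dos tres")
assert trim_to_shorter_length("one two three four", "uno dos tres", max_length=2) == ("one two", "uno dos")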
                                  def truncate_to_substring(text: str, substring: str, idx_occurrence: int)

                                  Truncate everything after the idx_occurrence occurrence of substring

                                  -
                                  - -Expand source code - -
                                  def truncate_to_substring(text: str, substring: str, idx_occurrence: int):
                                  -    """
                                  -    Truncate everything after the idx_occurrence occurrence of substring
                                  -    """
                                  -    assert idx_occurrence > 0, "idx_occurrence must be > 0"
                                  -    idx = -1
                                  -    for _ in range(idx_occurrence):
                                  -        idx = text.find(substring, idx + 1)
                                  -        if idx == -1:
                                  -            return text
                                  -    return text[:idx]
                                  -
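Occurrences are 1-indexed, and the text is returned unchanged when there are fewer than idx_occurrence matches; for example:

from mimir.data_utils import truncate_to_substring

text = "q: first question a: first answer q: second question"
assert truncate_to_substring(text, "q:", 2) == "q: first question a: first answer "
assert truncate_to_substring(text, "q:", 5) == text  # only two occurrences, so nothing is cut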
                                  @@ -764,270 +353,24 @@

                                  Methods

Dump neighbors to the local cache.

                                  -
                                  - -Expand source code - -
                                  def dump_neighbors(
                                  -    self,
                                  -    data,
                                  -    train: bool,
                                  -    num_neighbors: int,
                                  -    model: str = "bert",
                                  -    in_place_swap: bool = False,
                                  -):
                                  -    """
-    Dump neighbors to the local cache.
                                  -    """
                                  -    data_split = "train" if train else "test"
                                  -    data_split += "_neighbors"
                                  -    filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
                                  -        num_neighbors, model
                                  -    )
                                  -    if in_place_swap:
                                  -        filename += "_in_place_swap"
                                  -    custom_datasets.dump_to_cache(
                                  -        data,
                                  -        self.cache_dir,
                                  -        data_split,
                                  -        filename,
                                  -        min_length=self.config.min_words,
                                  -        max_length=self.config.max_words,
                                  -        n_samples=self.config.n_samples,
                                  -        max_tokens=self.config.max_tokens,
                                  -    )
                                  -
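A hypothetical usage sketch (ds is assumed to be an instance of the dataset wrapper documented here and neighbors a list of generated neighbor texts; the names are illustrative, not verified API). Neighbors for the neighborhood attack are cached under a filename that encodes the neighbor count, the perturbation model, and whether in-place swapping was used:

ds.dump_neighbors(neighbors, train=True, num_neighbors=25, model="bert", in_place_swap=True)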
                                  def dump_to_cache(self, data, data_split)
                                  -
                                  - -Expand source code - -
                                  def dump_to_cache(self, data, data_split):
                                  -    filename = self._get_name_to_save()
                                  -    custom_datasets.dump_to_cache(
                                  -        data,
                                  -        self.cache_dir,
                                  -        data_split,
                                  -        filename,
                                  -        min_length=self.config.min_words,
                                  -        max_length=self.config.max_words,
                                  -        n_samples=self.config.n_samples,
                                  -        max_tokens=self.config.max_tokens,
                                  -    )
                                  -
                                  def load(self, train: bool, mask_tokenizer=None, specific_source: str = None)
                                  -
                                  - -Expand source code - -
                                  def load(self, train: bool, mask_tokenizer=None, specific_source: str = None):
                                  -    data_split = "train" if train else "test"
                                  -    n_samples = self.config.n_samples
                                  -
-    # Load from a numpy file storing pretokenized samples in a 2D array of shape (num_samples, num_tokens_per_sample)
                                  -    if self.config.pretokenized:
                                  -        assert self.presampled
                                  -        # TODO: Pretokenized full documents (split into substrs) is not currently supported
                                  -        assert not self.config.full_doc
                                  -        data = np.load(self.presampled)
                                  -        return data
                                  -    elif (self.config.load_from_cache or self.config.load_from_hf):
                                  -        # Load from cache, if requested
                                  -        filename = self._get_name_to_save()
                                  -        data = custom_datasets.load_cached(
                                  -            self.cache_dir,
                                  -            data_split,
                                  -            filename,
                                  -            min_length=self.config.min_words,
                                  -            max_length=self.config.max_words,
                                  -            n_samples=self.config.n_samples,
                                  -            max_tokens=self.config.max_tokens,
                                  -            load_from_hf=self.config.load_from_hf
                                  -        )
                                  -        return data
                                  -    else:
                                  -        if self.presampled or self.config.full_doc:
                                  -            print("using presampled data")
                                  -            data = datasets.load_dataset(
                                  -                "json",
                                  -                data_files=self.presampled,
                                  -                split=f"train",
                                  -                cache_dir=self.cache_dir,
                                  -            )[self.key]
                                  -        elif self.name in custom_datasets.DATASETS:
                                  -            data = custom_datasets.load(self.name)
                                  -        elif self.name == "the_pile":
                                  -            min_load = max(10000, self.config.max_data)
                                  -            data = datasets.load_dataset(
                                  -                "json",
                                  -                data_files=os.path.join(
                                  -                    self.config.env_config.data_source,
                                  -                    "pile/00.jsonl.zst" if train else "pile/test.jsonl.zst",
                                  -                ),
                                  -                cache_dir=self.cache_dir,
                                  -                split=f"train[:{min_load}]",
                                  -            )
                                  -            specific_source_use = (
                                  -                self.config.specific_source
                                  -                if specific_source is None
                                  -                else specific_source
                                  -            )
                                  -            data = pile_selection_utility(
                                  -                data, self.key, wanted_source=specific_source_use
                                  -            )
                                  -        elif "human" in self.name:
                                  -            data = datasets.load_dataset(
                                  -                self.name, split=f"train[:100]", cache_dir=self.cache_dir
                                  -            )[self.key]
                                  -        elif "nthngdy" in self.name:
                                  -            data = datasets.load_dataset(
                                  -                self.name, split="test", cache_dir=self.cache_dir
                                  -            )[self.key]
                                  -        else:
                                  -            data = datasets.load_dataset(
                                  -                self.name, split=f"train", cache_dir=self.cache_dir
                                  -            )[self.key]
                                  -
                                  -    if not self.config.full_doc:
                                  -        # get unique examples
-        # then keep only the sufficiently long examples and take the first max_data of them to tokenize, to save time
                                  -        # then take just the examples that are <= 512 tokens (for the mask model)
                                  -        # then generate n_samples samples
                                  -        wsp_tokenizer = WhitespaceTokenizer()
                                  -
                                  -        # remove duplicates from the data
                                  -        data = list(dict.fromkeys(data))  # deterministic, as opposed to set()
                                  -
                                  -        whitespace_tokenized_spans = [
                                  -            (x, list(wsp_tokenizer.span_tokenize(x))) for x in data
                                  -        ]
                                  -
                                  -        # Pick samples with at least self.config.min_words words
                                  -        whitespace_tokenized_spans = [
                                  -            x
                                  -            for x in whitespace_tokenized_spans
                                  -            if len(x[1]) >= self.config.min_words
                                  -        ]
                                  -        if len(whitespace_tokenized_spans) == 0:
                                  -            raise ValueError("No examples with length >= min_words")
                                  -
                                  -        if self.config.max_words_cutoff:
                                  -            last_spans = [
                                  -                x[1][min(self.config.max_words, len(x[1])) - 1][1]
                                  -                for x in whitespace_tokenized_spans
                                  -            ]
                                  -            data = [
                                  -                x[0][:y] for x, y in zip(whitespace_tokenized_spans, last_spans)
                                  -            ]
                                  -        else:
                                  -            data = [
                                  -                x[0]
                                  -                for x in whitespace_tokenized_spans
                                  -                if len(x[1]) < self.config.max_words
                                  -            ]
                                  -            if len(data) == 0:
                                  -                raise ValueError("No examples with length < max_words")
                                  -
                                  -        # TODO: why shuffle
                                  -        # random.seed(0)
                                  -        # random.shuffle(data)
                                  -
                                  -        data = data[: self.config.max_data]
                                  -
                                  -        # If there is mask tokenizer, keep only examples with <= 512 tokens according to mask_tokenizer
                                  -        # this step has the extra effect of removing examples with low-quality/garbage content
                                  -        if mask_tokenizer:
                                  -            tokenized_data = mask_tokenizer(data)
                                  -            new_data = []
                                  -            for i, (x, y) in enumerate(zip(data, tokenized_data["input_ids"])):
                                  -                if len(y) <= self.config.max_tokens:
                                  -                    new_data.append(x)
                                  -                else:
                                  -                    print(
                                  -                        "Trimming text to nearest word that fits within mask tokenizer window"
                                  -                    )
                                  -                    max_token_char_span = tokenized_data.token_to_chars(
                                  -                        i, self.config.max_tokens - 1
                                  -                    )
                                  -                    x = x[: max_token_char_span.end]
                                  -                    token_truncated_word_spans = list(
                                  -                        wsp_tokenizer.span_tokenize(x)
                                  -                    )
                                  -
                                  -                    # Pop off the last "word" since it may be a word piece
                                  -                    second_last_span = token_truncated_word_spans[-2]
                                  -                    x = x[: second_last_span[1]]
                                  -
                                  -                    new_len = len(mask_tokenizer(x)["input_ids"])
                                  -                    assert new_len <= self.config.max_tokens
                                  -                    new_data.append(x)
                                  -            data = new_data
                                  -
-        # print stats about remaining data
                                  -        print(f"Total number of samples: {len(data)}")
                                  -        print(f"Average number of words: {np.mean([len(x.split()) for x in data])}")
                                  -
                                  -        if n_samples > len(data):
                                  -            print(f"WARNING: n_samples ({n_samples}) > len(data) ({len(data)})")
                                  -
                                  -    # Sample 'n_samples' examples
                                  -    data = data[:n_samples]
                                  -
                                  -    # Save to cache (if requested)
                                  -    if self.config.dump_cache:
                                  -        self.dump_to_cache(data, data_split)
                                  -
                                  -    return data
                                  -
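A hypothetical call sequence for the loader above (ds is assumed to be an instance of the dataset wrapper documented here, built from an ExperimentConfig; the source name is illustrative). With dump_cache set in the config, the filtered samples are also written to the local cache:

members = ds.load(train=True, mask_tokenizer=mask_tokenizer)
nonmembers = ds.load(train=False, specific_source="Wikipedia (en)")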
                                  def load_neighbors(self, train: bool, num_neighbors: int, model: str = 'bert', in_place_swap: bool = False)

                                  Load neighbors from cache (local or from HF)

                                  -
                                  - -Expand source code - -
                                  def load_neighbors(
                                  -    self,
                                  -    train: bool,
                                  -    num_neighbors: int,
                                  -    model: str = "bert",
                                  -    in_place_swap: bool = False,
                                  -):
                                  -    """
                                  -    Load neighbors from cache (local or from HF)
                                  -    """
                                  -    data_split = "train" if train else "test"
                                  -    data_split += "_neighbors"
                                  -    filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
                                  -        num_neighbors, model
                                  -    )
                                  -    if in_place_swap:
                                  -        filename += "_in_place_swap"
                                  -    data = custom_datasets.load_cached(
                                  -        self.cache_dir,
                                  -        data_split,
                                  -        filename,
                                  -        min_length=self.config.min_words,
                                  -        max_length=self.config.max_words,
                                  -        n_samples=self.config.n_samples,
                                  -        max_tokens=self.config.max_tokens,
                                  -        load_from_hf=self.config.load_from_hf
                                  -    )
                                  -    return data
                                  -
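And the matching hypothetical reload (arguments must mirror those used when the neighbors were dumped; with load_from_hf set, the cached file is fetched from the HF Hub instead of the local cache):

neighbors = ds.load_neighbors(train=True, num_neighbors=25, model="bert", in_place_swap=True)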
                                  @@ -1040,7 +383,6 @@

                                  Methods

                                  MIMIR -

                                  Index

                                    @@ -1077,7 +419,7 @@

                                    Data

                                    - \ No newline at end of file + diff --git a/docs/index.html b/docs/index.html index bf42f4f..78d7a60 100644 --- a/docs/index.html +++ b/docs/index.html @@ -2,18 +2,21 @@ - - + + mimir API documentation - - - - - - + + + + + + - - + +
                                    @@ -70,7 +73,6 @@

                                    Sub-modules

                                    MIMIR -

                                    Index

                                      @@ -90,7 +92,7 @@

                                      Index

                                      - \ No newline at end of file + diff --git a/docs/models.html b/docs/models.html index 568f229..9985c38 100644 --- a/docs/models.html +++ b/docs/models.html @@ -2,18 +2,21 @@ - - + + mimir.models API documentation - - - - - - + + + + + + - - + +
                                      @@ -23,635 +26,6 @@

                                      Module mimir.models

                                      Model definitions, with basic helper functions. Supports any model as long as it supports the functions specified in Model.
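As a rough sketch of what "supports the functions specified in Model" means in practice: a subclass only needs to set the model, tokenizer, device and name attributes and call load_model_properties(). The snippet below is an assumption-laden illustration (the class and model names are hypothetical), not part of the library:

import transformers
from mimir.models import Model

class MyCausalLM(Model):
    # Hypothetical wrapper around an arbitrary HF causal LM, usable as a target or scoring model.
    def __init__(self, config, name: str = "gpt2"):
        super().__init__(config)
        self.device = config.env_config.device
        self.name = name
        self.model = transformers.AutoModelForCausalLM.from_pretrained(name, cache_dir=self.cache_dir)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(name, cache_dir=self.cache_dir)
        self.load_model_properties()  # sets self.max_length and self.stride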

                                      -
                                      - -Expand source code - -
                                      """
                                      -    Model definitions, with basic helper functions. Supports any model as long as it supports the functions specified in Model.
                                      -"""
                                      -import torch
                                      -import torch.nn as nn
                                      -import openai
                                      -from typing import List
                                      -import numpy as np
                                      -import transformers
                                      -import time
                                      -from collections import defaultdict
                                      -from multiprocessing.pool import ThreadPool
                                      -import torch.nn.functional as F
                                      -from transformers import AutoTokenizer, AutoModelForSequenceClassification
                                      -from hf_olmo import *
                                      -
                                      -from mimir.config import ExperimentConfig
                                      -from mimir.custom_datasets import SEPARATOR
                                      -from mimir.data_utils import drop_last_word
                                      -
                                      -
                                      -class Model(nn.Module):
                                      -    """
                                      -        Base class (for LLMs).
                                      -    """
                                      -    def __init__(self, config: ExperimentConfig, **kwargs):
                                      -        super().__init__()
                                      -        self.model = None # Set by child class
                                      -        self.tokenizer = None # Set by child class
                                      -        self.config = config
                                      -        self.device = None
                                      -        self.device_map = None
                                      -        self.name = None
                                      -        self.kwargs = kwargs
                                      -        self.cache_dir = self.config.env_config.cache_dir
                                      -
                                      -    def to(self, device):
                                      -        """
                                      -            Shift model to a particular device.
                                      -        """
                                      -        self.model.to(device, non_blocking=True)
                                      -
                                      -    def load(self):
                                      -        """
                                      -            Load model onto GPU (and compile, if requested) if not already loaded with device map.
                                      -        """
                                      -        if not self.device_map:
                                      -            start = time.time()
                                      -            try:
                                      -                self.model.cpu()
                                      -            except NameError:
                                      -                pass
                                      -            if self.config.openai_config is None:
                                      -                self.model.to(self.device, non_blocking=True)
                                      -            if self.config.env_config.compile:
-                self.model = torch.compile(self.model)
                                      -            print(f'DONE ({time.time() - start:.2f}s)')
                                      -
                                      -    def unload(self):
                                      -        """
                                      -            Unload model from GPU
                                      -        """
                                      -        start = time.time()
                                      -        try:
                                      -            self.model.cpu()
                                      -        except NameError:
                                      -            pass
                                      -        print(f'DONE ({time.time() - start:.2f}s)')
                                      -
                                      -    def get_probabilities(self,
                                      -                          text: str,
                                      -                          tokens: np.ndarray = None,
                                      -                          no_grads: bool = True,
                                      -                          return_all_probs: bool = False):
                                      -        """
                                      -            Get the probabilities or log-softmaxed logits for a text under the current model.
                                      -            Args:
                                      -                text (str): The input text for which to calculate probabilities.
                                      -                tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
                                      -                are used instead of tokenizing the input text. Defaults to None.
                                      -
                                      -            Raises:
                                      -                ValueError: If the device or name attributes of the instance are not set.
                                      -
                                      -            Returns:
                                      -                list: A list of probabilities.
                                      -        """
                                      -        with torch.set_grad_enabled(not no_grads):
                                      -            if self.device is None or self.name is None:
                                      -                raise ValueError("Please set self.device and self.name in child class")
                                      -
                                      -            if tokens is not None:
                                      -                labels = torch.from_numpy(tokens.astype(np.int64)).type(torch.LongTensor)
                                      -                if labels.shape[0] != 1:
                                      -                    # expand first dimension
                                      -                    labels = labels.unsqueeze(0)
                                      -            else:
                                      -                tokenized = self.tokenizer(
                                      -                    text, return_tensors="pt")
                                      -                labels = tokenized.input_ids
                                      -
                                      -            target_token_log_prob = []
                                      -            all_token_log_prob = []
                                      -            for i in range(0, labels.size(1), self.stride):
                                      -                begin_loc = max(i + self.stride - self.max_length, 0)
                                      -                end_loc = min(i + self.stride, labels.size(1))
                                      -                trg_len = end_loc - i  # may be different from stride on last loop
                                      -                input_ids = labels[:, begin_loc:end_loc].to(self.device)
                                      -                target_ids = input_ids.clone()
                                      -                target_ids[:, :-trg_len] = -100
                                      -
                                      -                logits = self.model(input_ids, labels=target_ids).logits
                                      -                if no_grads:
                                      -                    logits = logits.cpu()
                                      -                shift_logits = logits[..., :-1, :].contiguous()
                                      -                log_probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
                                      -                shift_labels = target_ids[..., 1:]
                                      -                if no_grads:
                                      -                    shift_labels = shift_labels.cpu()
                                      -                shift_labels = shift_labels.contiguous()
                                      -                labels_processed = shift_labels[0]
                                      -
                                      -                del input_ids
                                      -                del target_ids
                                      -
                                      -                for i, token_id in enumerate(labels_processed):
                                      -                    if token_id != -100:
                                      -                        log_probability = log_probabilities[0, i, token_id]
                                      -                        if no_grads:
                                      -                            log_probability = log_probability.item()
                                      -                        target_token_log_prob.append(log_probability)
                                      -                        all_token_log_prob.append(log_probabilities[0, i])
                                      -            
                                      -            # Should be equal to # of tokens - 1 to account for shift
                                      -            assert len(target_token_log_prob) == labels.size(1) - 1
                                      -            all_token_log_prob = torch.stack(all_token_log_prob, dim=0)
                                      -            assert len(target_token_log_prob) == len(all_token_log_prob)
                                      -
                                      -        if not no_grads:
                                      -            target_token_log_prob = torch.stack(target_token_log_prob)
                                      -
                                      -        if not return_all_probs:
                                      -            return target_token_log_prob
                                      -        return target_token_log_prob, all_token_log_prob
                                      -
                                      -    @torch.no_grad()
                                      -    def get_ll(self,
                                      -               text: str,
                                      -               tokens: np.ndarray=None,
                                      -               probs = None):
                                      -        """
                                      -            Get the log likelihood of each text under the base_model.
                                      -
                                      -            Args:
                                      -                text (str): The input text for which to calculate the log likelihood.
                                      -                tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
                                      -                are used instead of tokenizing the input text. Defaults to None.
                                      -                probs (list, optional): An optional list of probabilities. If provided, these probabilities
                                      -                are used instead of calling the `get_probabilities` method. Defaults to None.
                                      -        """
                                      -        all_prob = probs if probs is not None else self.get_probabilities(text, tokens=tokens)
                                      -        return -np.mean(all_prob)
                                      -
                                      -    def load_base_model_and_tokenizer(self, model_kwargs):
                                      -        """
                                      -            Load the base model and tokenizer for a given model name.
                                      -        """
                                      -        if self.device is None or self.name is None:
                                      -            raise ValueError("Please set self.device and self.name in child class")
                                      -
                                      -        if self.config.openai_config is None:
                                      -            print(f'Loading BASE model {self.name}...')
                                      -            device_map = self.device_map # if self.device_map else 'cpu'
                                      -            if "silo" in self.name or "balanced" in self.name:
                                      -                from utils.transformers.model import OpenLMforCausalLM
                                      -                model = OpenLMforCausalLM.from_pretrained(
                                      -                    self.name, **model_kwargs, device_map=self.device, cache_dir=self.cache_dir)
-                # Extract the model from the model wrapper so we don't need to call model.model
                                      -            elif "llama" in self.name or "alpaca" in self.name:
                                      -                # TODO: This should be smth specified in config in case user has
                                      -                # llama is too big, gotta use device map
                                      -                model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir)
                                      -                self.device = 'cuda:1'
                                      -            elif "stablelm" in self.name.lower():  # models requiring custom code
                                      -                model = transformers.AutoModelForCausalLM.from_pretrained(
                                      -                    self.name, **model_kwargs, trust_remote_code=True, device_map=device_map, cache_dir=self.cache_dir)
                                      -            elif "olmo" in self.name.lower():
                                      -                model = transformers.AutoModelForCausalLM.from_pretrained(
                                      -                    self.name, **model_kwargs, trust_remote_code=True, cache_dir=self.cache_dir)
                                      -            else:
                                      -                model = transformers.AutoModelForCausalLM.from_pretrained(
                                      -                    self.name, **model_kwargs, device_map=device_map, cache_dir=self.cache_dir)
                                      -        else:
                                      -            model = None
                                      -
                                      -        optional_tok_kwargs = {}
                                      -        if "facebook/opt-" in self.name:
                                      -            print("Using non-fast tokenizer for OPT")
                                      -            optional_tok_kwargs['fast'] = False
                                      -        if self.config.dataset_member in ['pubmed'] or self.config.dataset_nonmember in ['pubmed']:
                                      -            optional_tok_kwargs['padding_side'] = 'left'
                                      -            self.pad_token = self.tokenizer.eos_token_id
                                      -        if "silo" in self.name or "balanced" in self.name:
                                      -            tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(
                                      -                "EleutherAI/gpt-neox-20b", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                      -        elif "datablations" in self.name:
                                      -            tokenizer = transformers.AutoTokenizer.from_pretrained(
                                      -                "gpt2", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                      -        elif "llama" in self.name or "alpaca" in self.name:
                                      -            tokenizer = transformers.LlamaTokenizer.from_pretrained(
                                      -                self.name, **optional_tok_kwargs, cache_dir=self.cache_dir)
                                      -        elif "pubmedgpt" in self.name:
                                      -            tokenizer = transformers.AutoTokenizer.from_pretrained(
                                      -                "stanford-crfm/BioMedLM", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                      -        else:
                                      -            tokenizer = transformers.AutoTokenizer.from_pretrained(
                                      -                self.name, **optional_tok_kwargs, cache_dir=self.cache_dir,
                                      -                trust_remote_code=True if "olmo" in self.name.lower() else False)
                                      -        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                                      -
                                      -        return model, tokenizer
                                      -
                                      -    def load_model_properties(self):
                                      -        """
                                      -            Load model properties, such as max length and stride.
                                      -        """
                                      -        # TODO: getting max_length of input could be more generic
                                      -        if "silo" in self.name or "balanced" in self.name:
                                      -            self.max_length = self.model.model.seq_len
                                      -        elif hasattr(self.model.config, 'max_position_embeddings'):
                                      -            self.max_length = self.model.config.max_position_embeddings
                                      -        elif hasattr(self.model.config, 'n_positions'):
                                      -            self.max_length = self.model.config.n_positions
                                      -        else:
                                      -            # Default window size
                                      -            self.max_length = 1024
                                      -        self.stride = self.max_length // 2
                                      -
                                      -
                                      -class ReferenceModel(Model):
                                      -    """
                                      -        Wrapper for reference model
                                      -    """
                                      -    def __init__(self, config: ExperimentConfig, name: str):
                                      -        super().__init__(config)
                                      -        self.device = self.config.env_config.device_aux
                                      -        self.name = name
                                      -        base_model_kwargs = {'revision': 'main'}
                                      -        if 'gpt-j' in self.name or 'neox' in self.name or 'llama' in self.name or 'alpaca' in self.name:
                                      -            base_model_kwargs.update(dict(torch_dtype=torch.float16))
                                      -        if 'gpt-j' in self.name:
                                      -            base_model_kwargs.update(dict(revision='float16'))
                                      -        if ':' in self.name:
                                      -            print("Applying ref model revision")
                                      -            # Allow them to provide revisions as part of model name, then parse accordingly
                                      -            split = self.name.split(':')
                                      -            self.name = split[0]
                                      -            base_model_kwargs.update(dict(revision=split[-1]))
                                      -        self.model, self.tokenizer = self.load_base_model_and_tokenizer(
                                      -            model_kwargs=base_model_kwargs)
                                      -        self.load_model_properties()
                                      -
                                      -    def load(self):
                                      -        """
-        Load reference model onto GPU(s)
                                      -        """
                                      -        if "llama" not in self.name and "alpaca" not in self.name:
                                      -            super().load()
                                      -
                                      -    def unload(self):
                                      -        """
                                      -        Unload reference model from GPU(s)
                                      -        """
                                      -        if "llama" not in self.name and "alpaca" not in self.name:
                                      -            super().unload()
                                      -
                                      -
                                      -class QuantileReferenceModel(Model):
                                      -    """
-        Wrapper for reference model, specifically used for quantile regression
                                      -    """
                                      -    def __init__(self, config: ExperimentConfig, name: str):
                                      -        super().__init__(config)
                                      -        self.device = self.config.env_config.device_aux
                                      -        self.name = name
                                      -        self.tokenizer = AutoTokenizer.from_pretrained(
                                      -            name, use_fast=False)
                                      -        self.model = AutoModelForSequenceClassification.from_pretrained(
                                      -            name,
                                      -            num_labels=2,
                                      -            max_position_embeddings=1024)
                                      -        # Modify model's last linear layer to have only 1 output
                                      -        self.model.classifier.linear_out = nn.Linear(self.model.classifier.linear_out.in_features, 1)
                                      -        self.load_model_properties()
                                      -
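# Illustrative sketch (editor addition, not part of the module source): the class above
# replaces the 2-label classification head with a single linear output, so the sequence
# classifier predicts one scalar quantile per text. A minimal standalone version of that
# head surgery with a hypothetical backbone (the attribute name depends on the
# architecture; the class above uses `classifier.linear_out`):
#
#   from transformers import AutoModelForSequenceClassification
#   import torch.nn as nn
#   m = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
#   m.classifier.out_proj = nn.Linear(m.classifier.out_proj.in_features, 1)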
                                      -
                                      -class LanguageModel(Model):
                                      -    """
-        Generic LM -- used most often for target model
                                      -    """
                                      -    def __init__(self, config: ExperimentConfig, **kwargs):
                                      -        super().__init__(config, **kwargs)
                                      -        self.device = self.config.env_config.device
                                      -        self.device_map = self.config.env_config.device_map
                                      -        # Use provided name (if provided)
                                      -        # Relevant for scoring-model scenario
                                      -        self.name = self.kwargs.get('name', self.config.base_model)
                                      -
                                      -        base_model_kwargs = {}
                                      -        if config.revision:
                                      -            base_model_kwargs.update(dict(revision=config.revision))
                                      -        if 'gpt-j' in self.name or 'neox' in self.name:
                                      -            base_model_kwargs.update(dict(torch_dtype=torch.float16))
                                      -        if 'gpt-j' in self.name:
                                      -            base_model_kwargs.update(dict(revision='float16'))
                                      -        self.model, self.tokenizer = self.load_base_model_and_tokenizer(
                                      -            model_kwargs=base_model_kwargs)
                                      -        self.load_model_properties()
                                      -
                                      -    @torch.no_grad()
                                      -    def get_ref(self, text: str, ref_model: ReferenceModel, tokens=None, probs=None):
                                      -        """
                                      -            Compute the loss of a given text calibrated against the text's loss under a reference model -- MIA baseline
                                      -        """
                                      -        lls = self.get_ll(text, tokens=tokens, probs=probs)
                                      -        lls_ref = ref_model.get_ll(text)
                                      -
                                      -        return lls - lls_ref
                                      -
                                      -    @torch.no_grad()
                                      -    def get_rank(self, text: str, log: bool=False):
                                      -        """
                                      -            Get the average rank of each observed token sorted by model likelihood
                                      -        """
                                      -        openai_config = self.config.openai_config
                                      -        assert openai_config is None, "get_rank not implemented for OpenAI models"
                                      -
                                      -        tokenized = self.tokenizer(text, return_tensors="pt").to(self.device)
                                      -        logits = self.model(**tokenized).logits[:,:-1]
                                      -        labels = tokenized.input_ids[:,1:]
                                      -
                                      -        # get rank of each label token in the model's likelihood ordering
                                      -        matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
                                      -
                                      -        assert matches.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {matches.shape}"
                                      -
                                      -        ranks, timesteps = matches[:,-1], matches[:,-2]
                                      -
                                      -        # make sure we got exactly one match for each timestep in the sequence
                                      -        assert (timesteps == torch.arange(len(timesteps)).to(timesteps.device)).all(), "Expected one match per timestep"
                                      -
                                      -        ranks = ranks.float() + 1 # convert to 1-indexed rank
                                      -        if log:
                                      -            ranks = torch.log(ranks)
                                      -
                                      -        return ranks.float().mean().item()
                                      -
                                      -    # TODO extend for longer sequences
                                      -    @torch.no_grad()
                                      -    def get_lls(self, texts: List[str], batch_size: int = 6):
                                      -        #return [self.get_ll(text) for text in texts] # -np.mean([self.get_ll(text) for text in texts])
                                      -        # tokenized = self.tokenizer(texts, return_tensors="pt", padding=True)
                                      -        # labels = tokenized.input_ids
                                      -        total_size = len(texts)
                                      -        losses = []
                                      -        for i in range(0, total_size, batch_size):
                                      -            # Delegate batches and tokenize
                                      -            batch = texts[i:i+batch_size]
                                      -            tokenized = self.tokenizer(batch, return_tensors="pt", padding=True, return_attention_mask=True)
                                      -            label_batch = tokenized.input_ids
                                      -            
                                      -            # # mask out padding tokens
                                      -            attention_mask = tokenized.attention_mask
                                      -            assert attention_mask.size() == label_batch.size()
                                      -
                                      -            needs_sliding = label_batch.size(1) > self.max_length // 2
                                      -            if not needs_sliding:
                                      -                label_batch = label_batch.to(self.device)
                                      -                attention_mask = attention_mask.to(self.device)
                                      -
                                      -            # Collect token probabilities per sample in batch
                                      -            all_prob = defaultdict(list)
                                      -            for i in range(0, label_batch.size(1), self.stride):
                                      -                begin_loc = max(i + self.stride - self.max_length, 0)
                                      -                end_loc = min(i + self.stride, label_batch.size(1))
                                      -                trg_len = end_loc - i  # may be different from stride on last loop
                                      -                input_ids = label_batch[:, begin_loc:end_loc]
                                      -                mask = attention_mask[:, begin_loc:end_loc]
                                      -                if needs_sliding:
                                      -                    input_ids = input_ids.to(self.device)
                                      -                    mask = mask.to(self.device)
                                      -                    
                                      -                target_ids = input_ids.clone()
                                      -                # Don't count padded tokens or tokens that already have computed probabilities
                                      -                target_ids[:, :-trg_len] = -100
                                      -                # target_ids[attention_mask == 0] = -100
                                      -                
                                      -                logits = self.model(input_ids, labels=target_ids, attention_mask=mask).logits.cpu()
                                      -                target_ids = target_ids.cpu()
                                      -                shift_logits = logits[..., :-1, :].contiguous()
                                      -                probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
                                      -                shift_labels = target_ids[..., 1:].contiguous()
                                      -
                                      -                for i, sample in enumerate(shift_labels):
                                      -                    for j, token_id in enumerate(sample):
                                      -                        if token_id != -100 and token_id != self.tokenizer.pad_token_id:
                                      -                            probability = probabilities[i, j, token_id].item()
                                      -                            all_prob[i].append(probability)
                                      -
                                      -                del input_ids
                                      -                del mask
                                      -            
                                      -            # average over each sample to get losses
                                      -            batch_losses = [-np.mean(all_prob[idx]) for idx in range(label_batch.size(0))]
                                      -            # print(batch_losses)
                                      -            losses.extend(batch_losses)
                                      -            del label_batch
                                      -            del attention_mask
                                      -        return losses #np.mean(losses)
                                      -
                                      -    def sample_from_model(self, texts: List[str], **kwargs):
                                      -        """
                                      -            Sample from base_model using ****only**** the first 30 tokens in each example as context
                                      -        """
                                      -        min_words = kwargs.get('min_words', 55)
                                      -        max_words = kwargs.get('max_words', 200)
                                      -        prompt_tokens = kwargs.get('prompt_tokens', 30)
                                      -
                                      -        # encode each text as a list of token ids
                                      -        if self.config.dataset_member == 'pubmed':
                                      -            texts = [t[:t.index(SEPARATOR)] for t in texts]
                                      -            all_encoded = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
                                      -        else:
                                      -            all_encoded = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
                                      -            all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}
                                      -
                                      -        decoded = ['' for _ in range(len(texts))]
                                      -
                                      -        # sample from the model until we get a sample with at least min_words words for each example
                                      -        # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
                                      -        tries = 0
                                      -        while (m := min(len(x.split()) for x in decoded)) < min_words and tries <  self.config.neighborhood_config.top_p:
                                      -            if tries != 0:
                                      -                print()
                                      -                print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")
                                      -
                                      -            sampling_kwargs = {}
                                      -            if self.config.do_top_p:
                                      -                sampling_kwargs['top_p'] = self.config.top_p
                                      -            elif self.config.do_top_k:
                                      -                sampling_kwargs['top_k'] = self.config.top_k
                                      -            #min_length = 50 if config.dataset_member in ['pubmed'] else 150
                                      -
                                      -            #outputs = base_model.generate(**all_encoded, min_length=min_length, max_length=max_length, do_sample=True, **sampling_kwargs, pad_token_id=base_tokenizer.eos_token_id, eos_token_id=base_tokenizer.eos_token_id)
                                      -            #removed minlen and attention mask min_length=min_length, max_length=200, do_sample=True,pad_token_id=base_tokenizer.eos_token_id,
                                      -            outputs = self.model.generate(**all_encoded, min_length=min_words*2, max_length=max_words*3,  **sampling_kwargs,  eos_token_id=self.tokenizer.eos_token_id)
                                      -            decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                                      -            tries += 1
                                      -
                                      -        return decoded
                                      -
                                      -    @torch.no_grad()
                                      -    def get_entropy(self, text: str):
                                      -        """
                                      -            Get average entropy of each token in the text
                                      -        """
                                      -        # raise NotImplementedError("get_entropy not implemented for OpenAI models")
                                      -        
                                      -        tokenized = self.tokenizer(text, return_tensors="pt").to(self.device)
                                      -        logits = self.model(**tokenized).logits[:,:-1]
                                      -        neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)
                                      -        return -neg_entropy.sum(-1).mean().item()
                                      -    
                                      -    @torch.no_grad()
                                      -    def get_max_norm(self, text: str, context_len=None, tk_freq_map=None):
                                      -        # TODO: update like other attacks
                                      -        tokenized = self.tokenizer(
                                      -            text, return_tensors="pt").to(self.device)
                                      -        labels = tokenized.input_ids
                                      -
                                      -        max_length = context_len if context_len is not None else self.max_length
                                      -        stride = max_length // 2 #self.stride
                                      -        all_prob = []
                                      -        for i in range(0, labels.size(1), stride):
                                      -            begin_loc = max(i + stride - max_length, 0)
                                      -            end_loc = min(i + stride, labels.size(1))
                                      -            trg_len = end_loc - i  # may be different from stride on last loop
                                      -            input_ids = labels[:, begin_loc:end_loc]
                                      -            target_ids = input_ids.clone()
                                      -            target_ids[:, :-trg_len] = -100
                                      -
                                      -            outputs = self.model(input_ids, labels=target_ids)
                                      -            logits = outputs.logits
                                      -            # Shift so that tokens < n predict n
                                      -            # print(logits.shape)
                                      -            shift_logits = logits[..., :-1, :].contiguous()
                                      -            # shift_logits = torch.transpose(shift_logits, 1, 2)
                                      -            probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
                                      -            shift_labels = target_ids[..., 1:].contiguous()
                                      -            labels_processed = shift_labels[0]
                                      -
                                      -            for i, token_id in enumerate(labels_processed):
                                      -                if token_id != -100:
                                      -                    probability = probabilities[0, i, token_id].item()
                                      -                    max_tk_prob = torch.max(probabilities[0, i]).item()
                                      -                    tk_weight = max(tk_freq_map[token_id.item()], 1) / sum(tk_freq_map.values()) if tk_freq_map is not None else 1
                                      -                    if tk_weight == 0:
                                      -                        print("0 count token", token_id.item())
                                      -                    tk_norm = tk_weight
                                      -                    all_prob.append((1 - (max_tk_prob - probability)) / tk_norm)
                                      -
                                      -        # Should be equal to # of tokens - 1 to account for shift
                                      -        assert len(all_prob) == labels.size(1) - 1
                                      -        return -np.mean(all_prob)
                                      -
                                      -
                                      -class OpenAI_APIModel(LanguageModel):
                                      -    """
                                      -        Wrapper for OpenAI API calls
                                      -    """
                                      -    def __init__(self, config: ExperimentConfig, **kwargs):
                                      -        super().__init__(config, **kwargs)
                                      -        self.model = None
                                      -        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2', cache_dir=self.cache_dir)
                                      -        self.API_TOKEN_COUNTER = 0
                                      -    
                                      -    @property
                                      -    def api_calls(self):
                                      -        """
                                      -            Get the number of tokens used in API calls
                                      -        """
                                      -        return self.API_TOKEN_COUNTER
                                      -
                                      -    @torch.no_grad()
                                      -    def get_ll(self, text: str):
                                      -        """
-            Get the log likelihood of the given text under the base_model
                                      -        """
                                      -        openai_config = self.config.openai_config
                                      -
                                      -        kwargs = {"engine": openai_config.model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0}
                                      -        r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
                                      -        result = r['choices'][0]
                                      -        tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
                                      -
                                      -        assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
                                      -
                                      -        return np.mean(logprobs)
                                      -
                                      -    @torch.no_grad()
                                      -    def get_ref(self, text: str, ref_model: ReferenceModel):
                                      -        """
-            Get the likelihood ratio of the given text under the base_model -- MIA baseline
                                      -        """
                                      -        raise NotImplementedError("OpenAI model not implemented for LIRA")
                                      -        openai_config = self.config.openai_config
                                      -        kwargs = {"engine": openai_config.model, "temperature": 0,
                                      -                    "max_tokens": 0, "echo": True, "logprobs": 0}
                                      -        r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
                                      -        result = r['choices'][0]
                                      -        tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
                                      -
                                      -        assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
                                      -
                                      -        return np.mean(logprobs)
                                      -
                                      -    def get_lls(self, texts: str):
                                      -
                                      -        # use GPT2_TOKENIZER to get total number of tokens
                                      -        total_tokens = sum(len(self.tokenizer.encode(text)) for text in texts)
                                      -        self.API_TOKEN_COUNTER += total_tokens * 2  # multiply by two because OpenAI double-counts echo_prompt tokens
                                      -
                                      -        pool = ThreadPool(self.config.batch_size)
                                      -        return pool.map(self.get_ll, texts)
                                      -
                                      -    def _openai_sample(self, p: str):
                                      -        openai_config = self.config.openai_config
                                      -        if self.config.dataset_member != 'pubmed':  # keep Answer: prefix for pubmed
                                      -            p = drop_last_word(p)
                                      -
                                      -        # sample from the openai model
                                      -        kwargs = { "engine": openai_config.model, "max_tokens": 200 }
                                      -        if self.config.do_top_p:
                                      -            kwargs['top_p'] = self.config.top_p
                                      -    
                                      -        r = openai.Completion.create(prompt=f"{p}", **kwargs)
                                      -        return p + r['choices'][0].text
                                      -
                                      -
                                      -    def sample_from_model(self, texts: List[str], **kwargs):
                                      -        """
                                      -            Sample from base_model using ****only**** the first 30 tokens in each example as context
                                      -        """
                                      -        prompt_tokens = kwargs.get('prompt_tokens', 30)
                                      -        base_tokenizer = kwargs.get('base_tokenizer', None)
                                      -        if base_tokenizer is None:
                                      -            raise ValueError("Please provide base_tokenizer")
                                      -
                                      -        # encode each text as a list of token ids
                                      -        if self.config.dataset_member == 'pubmed':
                                      -            texts = [t[:t.index(SEPARATOR)] for t in texts]
                                      -            all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(self.device)
                                      -        else:
                                      -            all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(self.device)
                                      -            all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}
                                      -
                                      -        # decode the prefixes back into text
                                      -        prefixes = base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
                                      -        pool = ThreadPool(self.config.batch_size)
                                      -
                                      -        decoded = pool.map(self._openai_sample, prefixes)
                                      -
                                      -        # count total number of tokens with GPT2_TOKENIZER
                                      -        total_tokens = sum(len(self.tokenizer.encode(x)) for x in decoded)
                                      -        self.API_TOKEN_COUNTER += total_tokens
                                      -
                                      -        return decoded
                                      -    
                                      -    @torch.no_grad()
                                      -    def get_entropy(self, text: str):
                                      -        """
                                      -            Get average entropy of each token in the text
                                      -        """
                                      -        raise NotImplementedError("get_entropy not implemented for OpenAI models")
                                      -
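# Illustrative sketch (editor addition): the token accounting that get_lls above performs
# before issuing API requests -- count GPT-2 tokens and double them, since the API
# double-counts echoed prompt tokens. Only the tokenizer is needed, so this runs offline;
# the texts are made up.
from transformers import GPT2Tokenizer

texts = [
    "Membership inference asks whether a sample was part of the training data.",
    "Reference models are used to calibrate the target model's loss.",
]
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
total_tokens = sum(len(tokenizer.encode(t)) for t in texts)
print(f"{total_tokens} prompt tokens -> ~{total_tokens * 2} tokens counted against the API")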
                                      @@ -901,21 +275,6 @@

                                      Subclasses

                                      -

                                      Class variables

                                      -
                                      -
                                      var call_super_init : bool
                                      -
                                      -
                                      -
                                      -
                                      var dump_patches : bool
                                      -
                                      -
                                      -
                                      -
                                      var training : bool
                                      -
                                      -
                                      -
                                      -

                                      Methods

                                      @@ -923,262 +282,42 @@

                                      Methods

                                      Get average entropy of each token in the text

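A minimal standalone sketch of the entropy computation documented above, using random logits instead of a loaded model (shapes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 16, 50257)   # (batch, seq_len, vocab); 50257 is the GPT-2 vocab size
neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)
avg_entropy = -neg_entropy.sum(-1).mean().item()   # average per-token entropy in nats
print(avg_entropy)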
                                      def get_lls(self, texts: List[str], batch_size: int = 6)
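Each entry returned by get_lls is the negative mean per-token log-likelihood of a text. A standalone sketch for a single short text with GPT-2 (model choice and text are illustrative; the method above additionally batches texts and slides a window over long inputs):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

text = "Membership inference attacks score how familiar a model is with a text."
ids = tokenizer(text, return_tensors="pt").input_ids
with torch.no_grad():
    log_probs = torch.log_softmax(model(ids).logits[:, :-1], dim=-1)
token_lls = log_probs.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
print(-token_lls.mean().item())   # the per-text loss that get_lls would append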
                                      def get_max_norm(self, text: str, context_len=None, tk_freq_map=None)
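A standalone sketch of the per-token quantity get_max_norm accumulates: the gap between the most likely token's log-probability and the observed token's log-probability, optionally divided by a token-frequency weight (random values here, purely illustrative):

import numpy as np
import torch

log_probs = torch.log_softmax(torch.randn(8, 100), dim=-1)   # 8 positions, toy vocab of 100
labels = torch.randint(0, 100, (8,))                          # observed token ids

scores = []
for i, token_id in enumerate(labels):
    p = log_probs[i, token_id].item()
    max_p = log_probs[i].max().item()
    tk_weight = 1.0                                           # stand-in for the tk_freq_map weight
    scores.append((1 - (max_p - p)) / tk_weight)
print(-np.mean(scores))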
                                      def get_rank(self, text: str, log: bool = False)

                                      Get the average rank of each observed token sorted by model likelihood

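A standalone sketch of the rank computation described above, on random logits (the real method tokenizes the text and uses the target model's logits):

import torch

logits = torch.randn(1, 10, 1000)            # (batch, seq_len, vocab)
labels = torch.randint(0, 1000, (1, 10))     # observed next tokens

matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
ranks = matches[:, -1].float() + 1           # 1-indexed rank of each observed token
print(ranks.mean().item())                   # lower average rank = model finds the text likelier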
                                      def get_ref(self, text: str, ref_model: ReferenceModel, tokens=None, probs=None)

                                      Compute the loss of a given text calibrated against the text's loss under a reference model – MIA baseline

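The calibrated score returned by get_ref is simply the difference between two losses, each a negative mean token log-likelihood as returned by get_ll. A tiny illustration with made-up numbers:

target_loss = 2.1   # get_ll of the text under the target model
ref_loss = 3.4      # get_ll of the same text under the reference model
score = target_loss - ref_loss   # more negative => target fits the text unusually well
print(score)        # ~ -1.3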
                                      def sample_from_model(self, texts: List[str], **kwargs)

                                      Sample from base_model using *only* the first 30 tokens in each example as context

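A standalone sketch of the regenerate-until-long-enough loop used above, with a stub in place of model.generate (the real method keeps only the first 30 tokens of each text as the prompt and re-samples the whole batch whenever any output is still too short):

import random

def fake_generate(prompts):
    # Stand-in for model.generate + batch_decode: one continuation per prompt.
    return [p + " " + " ".join("word" for _ in range(random.randint(20, 80))) for p in prompts]

prompts = ["Once upon a time", "In a recent study"]
min_words, max_tries = 55, 10

decoded = ["" for _ in prompts]
tries = 0
while min(len(x.split()) for x in decoded) < min_words and tries < max_tries:
    decoded = fake_generate(prompts)   # regenerate every prompt if any output is too short
    tries += 1
print([len(x.split()) for x in decoded])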

                                      Inherited members

                                      • Model: -


                                        Methods

                                        -
                                        -def forward(self, *input: Any) ‑> None -
                                        -
                                        -

                                        Define the computation performed at every call.

                                        -

                                        Should be overridden by all subclasses.

                                        -
                                        -

                                        Note

                                        -

Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

                                        -
                                        -
                                        - -Expand source code - -
                                        def _forward_unimplemented(self, *input: Any) -> None:
                                        -    r"""Define the computation performed at every call.
                                        -
                                        -    Should be overridden by all subclasses.
                                        -
                                        -    .. note::
                                        -        Although the recipe for forward pass needs to be defined within
                                        -        this function, one should call the :class:`Module` instance afterwards
                                        -        instead of this since the former takes care of running the
                                        -        registered hooks while the latter silently ignores them.
                                        -    """
                                        -    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
                                        -
                                        -
                                        def get_ll(self, text: str, tokens: numpy.ndarray = None, probs=None)
                                        @@ -1493,28 +586,6 @@

                                        Args

                                        An optional list of probabilities. If provided, these probabilities

                                        are used instead of calling the get_probabilities method. Defaults to None.

                                        -
- Expand source code
                                        @torch.no_grad()
                                        -def get_ll(self,
                                        -           text: str,
                                        -           tokens: np.ndarray=None,
                                        -           probs = None):
                                        -    """
-        Get the log likelihood of the given text under the base_model.
                                        -
                                        -        Args:
                                        -            text (str): The input text for which to calculate the log likelihood.
                                        -            tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
                                        -            are used instead of tokenizing the input text. Defaults to None.
                                        -            probs (list, optional): An optional list of probabilities. If provided, these probabilities
                                        -            are used instead of calling the `get_probabilities` method. Defaults to None.
                                        -    """
                                        -    all_prob = probs if probs is not None else self.get_probabilities(text, tokens=tokens)
                                        -    return -np.mean(all_prob)
                                        -
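get_ll averages the per-token log-probabilities (computed by get_probabilities, or taken from the probs argument if supplied) and negates the result. A minimal sketch of that final step with illustrative numbers:

import numpy as np

token_log_probs = [-3.2, -0.8, -1.5, -2.9]   # illustrative per-token log-likelihoods
print(-np.mean(token_log_probs))             # ~2.1 -- the value get_ll returns for this text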
def get_probabilities(self, text: str, tokens: numpy.ndarray = None, no_grads: bool = True, return_all_probs: bool = False)
@@ -1539,244 +610,36 @@

                                        Returns

                                        list
                                        A list of probabilities.
                                        -
- Expand source code
                                        def get_probabilities(self,
                                        -                      text: str,
                                        -                      tokens: np.ndarray = None,
                                        -                      no_grads: bool = True,
                                        -                      return_all_probs: bool = False):
                                        -    """
                                        -        Get the probabilities or log-softmaxed logits for a text under the current model.
                                        -        Args:
                                        -            text (str): The input text for which to calculate probabilities.
                                        -            tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
                                        -            are used instead of tokenizing the input text. Defaults to None.
                                        -
                                        -        Raises:
                                        -            ValueError: If the device or name attributes of the instance are not set.
                                        -
                                        -        Returns:
                                        -            list: A list of probabilities.
                                        -    """
                                        -    with torch.set_grad_enabled(not no_grads):
                                        -        if self.device is None or self.name is None:
                                        -            raise ValueError("Please set self.device and self.name in child class")
                                        -
                                        -        if tokens is not None:
                                        -            labels = torch.from_numpy(tokens.astype(np.int64)).type(torch.LongTensor)
                                        -            if labels.shape[0] != 1:
                                        -                # expand first dimension
                                        -                labels = labels.unsqueeze(0)
                                        -        else:
                                        -            tokenized = self.tokenizer(
                                        -                text, return_tensors="pt")
                                        -            labels = tokenized.input_ids
                                        -
                                        -        target_token_log_prob = []
                                        -        all_token_log_prob = []
                                        -        for i in range(0, labels.size(1), self.stride):
                                        -            begin_loc = max(i + self.stride - self.max_length, 0)
                                        -            end_loc = min(i + self.stride, labels.size(1))
                                        -            trg_len = end_loc - i  # may be different from stride on last loop
                                        -            input_ids = labels[:, begin_loc:end_loc].to(self.device)
                                        -            target_ids = input_ids.clone()
                                        -            target_ids[:, :-trg_len] = -100
                                        -
                                        -            logits = self.model(input_ids, labels=target_ids).logits
                                        -            if no_grads:
                                        -                logits = logits.cpu()
                                        -            shift_logits = logits[..., :-1, :].contiguous()
                                        -            log_probabilities = torch.nn.functional.log_softmax(shift_logits, dim=-1)
                                        -            shift_labels = target_ids[..., 1:]
                                        -            if no_grads:
                                        -                shift_labels = shift_labels.cpu()
                                        -            shift_labels = shift_labels.contiguous()
                                        -            labels_processed = shift_labels[0]
                                        -
                                        -            del input_ids
                                        -            del target_ids
                                        -
-            # use a separate index name so the window index i above is not shadowed
-            for idx, token_id in enumerate(labels_processed):
-                if token_id != -100:
-                    log_probability = log_probabilities[0, idx, token_id]
-                    if no_grads:
-                        log_probability = log_probability.item()
-                    target_token_log_prob.append(log_probability)
-                    all_token_log_prob.append(log_probabilities[0, idx])
                                        -        
                                        -        # Should be equal to # of tokens - 1 to account for shift
                                        -        assert len(target_token_log_prob) == labels.size(1) - 1
                                        -        all_token_log_prob = torch.stack(all_token_log_prob, dim=0)
                                        -        assert len(target_token_log_prob) == len(all_token_log_prob)
                                        -
                                        -    if not no_grads:
                                        -        target_token_log_prob = torch.stack(target_token_log_prob)
                                        -
                                        -    if not return_all_probs:
                                        -        return target_token_log_prob
                                        -    return target_token_log_prob, all_token_log_prob
                                        -
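For reference, the scoring above is a sliding-window log-softmax over shifted logits. A minimal self-contained sketch of the same idea (single window, no striding; the model name and text are placeholders, and this is not the mimir API itself):

# Sketch: per-token log-probabilities with plain transformers, mirroring the logic removed above.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

text = "The quick brown fox jumps over the lazy dog."
ids = tokenizer(text, return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(ids).logits                 # shape [1, T, V]
shift_logits = logits[:, :-1, :]               # position t predicts token t+1
shift_labels = ids[:, 1:]
log_probs = F.log_softmax(shift_logits, dim=-1)
token_log_probs = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)[0]
# len(token_log_probs) == ids.size(1) - 1, matching the assertion in the original code.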
                                        def load(self)

                                        Load model onto GPU (and compile, if requested) if not already loaded with device map.

                                        -
                                        def load(self):
                                        -    """
                                        -        Load model onto GPU (and compile, if requested) if not already loaded with device map.
                                        -    """
                                        -    if not self.device_map:
                                        -        start = time.time()
                                        -        try:
                                        -            self.model.cpu()
                                        -        except NameError:
                                        -            pass
                                        -        if self.config.openai_config is None:
                                        -            self.model.to(self.device, non_blocking=True)
                                        -        if self.config.env_config.compile:
-            self.model = torch.compile(self.model)  # torch.compile returns the optimized module; keep it
                                        -        print(f'DONE ({time.time() - start:.2f}s)')
                                        -
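The load step above amounts to a device move plus optional compilation. A standalone sketch under those assumptions (placeholder model; note that torch.compile returns the optimized module, so its return value should be kept):

# Sketch: move a model to GPU (if available) and optionally compile it.
import time
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
device = "cuda" if torch.cuda.is_available() else "cpu"

start = time.time()
model.to(device, non_blocking=True)
if hasattr(torch, "compile"):        # torch >= 2.0
    model = torch.compile(model)     # keep the returned optimized module
print(f"DONE ({time.time() - start:.2f}s)")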
                                        def load_base_model_and_tokenizer(self, model_kwargs)

                                        Load the base model and tokenizer for a given model name.

                                        -
                                        def load_base_model_and_tokenizer(self, model_kwargs):
                                        -    """
                                        -        Load the base model and tokenizer for a given model name.
                                        -    """
                                        -    if self.device is None or self.name is None:
                                        -        raise ValueError("Please set self.device and self.name in child class")
                                        -
                                        -    if self.config.openai_config is None:
                                        -        print(f'Loading BASE model {self.name}...')
                                        -        device_map = self.device_map # if self.device_map else 'cpu'
                                        -        if "silo" in self.name or "balanced" in self.name:
                                        -            from utils.transformers.model import OpenLMforCausalLM
                                        -            model = OpenLMforCausalLM.from_pretrained(
                                        -                self.name, **model_kwargs, device_map=self.device, cache_dir=self.cache_dir)
-            # Extract the model from the model wrapper so we don't need to call model.model
                                        -        elif "llama" in self.name or "alpaca" in self.name:
-            # TODO: this should be specified in the config, depending on the user's hardware setup
-            # llama is too big for a single device, so a device map is used here
                                        -            model = transformers.AutoModelForCausalLM.from_pretrained(self.name, **model_kwargs, device_map="balanced_low_0", cache_dir=self.cache_dir)
                                        -            self.device = 'cuda:1'
                                        -        elif "stablelm" in self.name.lower():  # models requiring custom code
                                        -            model = transformers.AutoModelForCausalLM.from_pretrained(
                                        -                self.name, **model_kwargs, trust_remote_code=True, device_map=device_map, cache_dir=self.cache_dir)
                                        -        elif "olmo" in self.name.lower():
                                        -            model = transformers.AutoModelForCausalLM.from_pretrained(
                                        -                self.name, **model_kwargs, trust_remote_code=True, cache_dir=self.cache_dir)
                                        -        else:
                                        -            model = transformers.AutoModelForCausalLM.from_pretrained(
                                        -                self.name, **model_kwargs, device_map=device_map, cache_dir=self.cache_dir)
                                        -    else:
                                        -        model = None
                                        -
                                        -    optional_tok_kwargs = {}
                                        -    if "facebook/opt-" in self.name:
                                        -        print("Using non-fast tokenizer for OPT")
                                        -        optional_tok_kwargs['fast'] = False
                                        -    if self.config.dataset_member in ['pubmed'] or self.config.dataset_nonmember in ['pubmed']:
                                        -        optional_tok_kwargs['padding_side'] = 'left'
                                        -        self.pad_token = self.tokenizer.eos_token_id
                                        -    if "silo" in self.name or "balanced" in self.name:
                                        -        tokenizer = transformers.GPTNeoXTokenizerFast.from_pretrained(
                                        -            "EleutherAI/gpt-neox-20b", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                        -    elif "datablations" in self.name:
                                        -        tokenizer = transformers.AutoTokenizer.from_pretrained(
                                        -            "gpt2", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                        -    elif "llama" in self.name or "alpaca" in self.name:
                                        -        tokenizer = transformers.LlamaTokenizer.from_pretrained(
                                        -            self.name, **optional_tok_kwargs, cache_dir=self.cache_dir)
                                        -    elif "pubmedgpt" in self.name:
                                        -        tokenizer = transformers.AutoTokenizer.from_pretrained(
                                        -            "stanford-crfm/BioMedLM", **optional_tok_kwargs, cache_dir=self.cache_dir)
                                        -    else:
                                        -        tokenizer = transformers.AutoTokenizer.from_pretrained(
                                        -            self.name, **optional_tok_kwargs, cache_dir=self.cache_dir,
                                        -            trust_remote_code=True if "olmo" in self.name.lower() else False)
                                        -    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                                        -
                                        -    return model, tokenizer
                                        -
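Outside the special-cased model families, the method above reduces to a standard transformers load plus pad-token registration. A minimal sketch of that generic branch (model name and cache directory are placeholders):

# Sketch: generic causal-LM + tokenizer load with a [PAD] token, as in the non-special-cased branch.
import transformers

name = "gpt2"                 # placeholder model name
cache_dir = "/tmp/hf_cache"   # placeholder cache directory
model = transformers.AutoModelForCausalLM.from_pretrained(name, cache_dir=cache_dir)
tokenizer = transformers.AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# If the added [PAD] token is new to the vocabulary, the embeddings may need resizing:
model.resize_token_embeddings(len(tokenizer))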
                                        def load_model_properties(self)

                                        Load model properties, such as max length and stride.

                                        -
                                        def load_model_properties(self):
                                        -    """
                                        -        Load model properties, such as max length and stride.
                                        -    """
                                        -    # TODO: getting max_length of input could be more generic
                                        -    if "silo" in self.name or "balanced" in self.name:
                                        -        self.max_length = self.model.model.seq_len
                                        -    elif hasattr(self.model.config, 'max_position_embeddings'):
                                        -        self.max_length = self.model.config.max_position_embeddings
                                        -    elif hasattr(self.model.config, 'n_positions'):
                                        -        self.max_length = self.model.config.n_positions
                                        -    else:
                                        -        # Default window size
                                        -        self.max_length = 1024
                                        -    self.stride = self.max_length // 2
                                        -
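The same max-length lookup can be done from a Hugging Face config alone; a small sketch (placeholder model name, with 1024 as the fallback window, matching the method above):

# Sketch: derive max_length and stride from a model config.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("gpt2")   # placeholder model name
max_length = getattr(cfg, "max_position_embeddings",
                     getattr(cfg, "n_positions", 1024))  # fall back to a 1024-token window
stride = max_length // 2                                  # half-window stride, as above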
                                        def to(self, device)

                                        Shift model to a particular device.

                                        -
                                        def to(self, device):
                                        -    """
                                        -        Shift model to a particular device.
                                        -    """
                                        -    self.model.to(device, non_blocking=True)
                                        -
                                        def unload(self)

                                        Unload model from GPU

                                        -
                                        def unload(self):
                                        -    """
                                        -        Unload model from GPU
                                        -    """
                                        -    start = time.time()
                                        -    try:
                                        -        self.model.cpu()
                                        -    except NameError:
                                        -        pass
                                        -    print(f'DONE ({time.time() - start:.2f}s)')
                                        -
                                        @@ -1906,24 +769,9 @@

                                        Ancestors

                                      • Model
                                      • torch.nn.modules.module.Module
                                      -

                                      Class variables

                                      -
                                      -
                                      var call_super_init : bool
                                      -
                                      -
                                      -
                                      -
                                      var dump_patches : bool
                                      -
                                      -
                                      -
                                      -
                                      var training : bool
                                      -
                                      -
                                      -
                                      -

                                      Instance variables

                                      -
                                      var api_calls
                                      +
                                      prop api_calls

                                      Get the number of tokens used in API calls

                                      @@ -1946,45 +794,12 @@

                                      Methods

                                      Get the log likelihood of each text under the base_model

                                      -
                                      @torch.no_grad()
                                      -def get_ll(self, text: str):
                                      -    """
                                      -        Get the log likelihood of each text under the base_model
                                      -    """
                                      -    openai_config = self.config.openai_config
                                      -
                                      -    kwargs = {"engine": openai_config.model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0}
                                      -    r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
                                      -    result = r['choices'][0]
                                      -    tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
                                      -
                                      -    assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
                                      -
                                      -    return np.mean(logprobs)
                                      -
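The call above assumes the legacy (pre-1.0) openai SDK, where echoed completions return per-token logprobs. A sketch of just the post-processing step, applied to a hand-written response-shaped dict (not real API output):

# Sketch: mean token log-likelihood from a legacy Completion-style response.
import numpy as np

response = {
    "choices": [{
        "logprobs": {
            "tokens": ["<|endoftext|>", "Hello", " world"],
            "token_logprobs": [None, -3.2, -1.1],
        }
    }]
}
result = response["choices"][0]
# Drop the first position: the prepended <|endoftext|> prompt token has no logprob.
tokens = result["logprobs"]["tokens"][1:]
logprobs = result["logprobs"]["token_logprobs"][1:]
assert len(tokens) == len(logprobs)
print(np.mean(logprobs))   # mean token log-likelihood, as get_ll returns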
                                      def get_lls(self, texts: str)
                                      -
                                      def get_lls(self, texts: str):
                                      -
-    # use the model's tokenizer to estimate the total number of tokens
                                      -    total_tokens = sum(len(self.tokenizer.encode(text)) for text in texts)
                                      -    self.API_TOKEN_COUNTER += total_tokens * 2  # multiply by two because OpenAI double-counts echo_prompt tokens
                                      -
                                      -    pool = ThreadPool(self.config.batch_size)
                                      -    return pool.map(self.get_ll, texts)
                                      -
                                      def get_ref(self, text: str, ref_model: ReferenceModel) @@ -1992,34 +807,12 @@

                                      Methods

                                      Get the likelihood ratio of each text under the base_model – MIA baseline

                                      -
                                      @torch.no_grad()
                                      -def get_ref(self, text: str, ref_model: ReferenceModel):
                                      -    """
-        Get the likelihood ratio of each text under the base_model -- MIA baseline
-    """
-    raise NotImplementedError("OpenAI model not implemented for LIRA")
-    # NOTE: the code below is unreachable until the NotImplementedError above is removed
                                      -    openai_config = self.config.openai_config
                                      -    kwargs = {"engine": openai_config.model, "temperature": 0,
                                      -                "max_tokens": 0, "echo": True, "logprobs": 0}
                                      -    r = openai.Completion.create(prompt=f"<|endoftext|>{text}", **kwargs)
                                      -    result = r['choices'][0]
                                      -    tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]
                                      -
                                      -    assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
                                      -
                                      -    return np.mean(logprobs)
                                      -

                                      Inherited members

                                      • LanguageModel: -

                                        Class variables

                                        -
                                        -
                                        var call_super_init : bool
                                        -
                                        -
                                        -
                                        -
                                        var dump_patches : bool
                                        -
                                        -
                                        -
                                        -
                                        var training : bool
                                        -
                                        -
                                        -
                                        -

                                        Inherited members

                                        • Model: -

                                          Class variables

                                          -
                                          -
                                          var call_super_init : bool
                                          -
                                          -
                                          -
                                          -
                                          var dump_patches : bool
                                          -
                                          -
                                          -
                                          -
                                          var training : bool
                                          -
                                          -
                                          -
                                          -

                                          Methods

                                          @@ -2173,41 +935,18 @@

                                          Methods

Load reference model onto GPU(s)

                                          -
                                          def load(self):
                                          -    """
-    Load reference model onto GPU(s)
                                          -    """
                                          -    if "llama" not in self.name and "alpaca" not in self.name:
                                          -        super().load()
                                          -
                                          def unload(self)

                                          Unload reference model from GPU(s)

                                          -
                                          def unload(self):
                                          -    """
                                          -    Unload reference model from GPU(s)
                                          -    """
                                          -    if "llama" not in self.name and "alpaca" not in self.name:
                                          -        super().unload()
                                          -

                                          Inherited members

                                          • Model:
                                              -
                                            • forward
                                            • get_ll
                                            • get_probabilities
                                            • load_base_model_and_tokenizer
                                            • @@ -2226,7 +965,6 @@

                                              Inherited members

                                              MIMIR -

                                              Index

                                                @@ -2241,60 +979,42 @@

                                                Index

                                              • LanguageModel

                                              • Model

                                              • OpenAI_APIModel

                                                -
                                                  +
                                                • QuantileReferenceModel

                                                  -
                                                • ReferenceModel

                                                • @@ -2304,7 +1024,7 @@

                                                  -

                                                  Generated by pdoc 0.10.0.

                                                  +

                                                  Generated by pdoc 0.11.1.

                                                  - \ No newline at end of file + diff --git a/docs/plot_utils.html b/docs/plot_utils.html index 0edaa63..6eec9ed 100644 --- a/docs/plot_utils.html +++ b/docs/plot_utils.html @@ -2,18 +2,21 @@ - - + + mimir.plot_utils API documentation - - - - - - + + + + + + - - + +
                                                  @@ -23,139 +26,6 @@

                                                  Module mimir.plot_utils

                                                  Utilities related to plotting.

                                                  -
                                                  """
                                                  -    Utilities related to plotting.
                                                  -"""
                                                  -
                                                  -import matplotlib.pyplot as plt
                                                  -# Set high DPI
                                                  -plt.rcParams['figure.dpi'] = 300
                                                  -
                                                  -
                                                  -# 15 colorblind-friendly colors
                                                  -COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
                                                  -            "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
                                                  -            "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
                                                  -
                                                  -
                                                  -def save_roc_curves(experiments, save_folder, model_name, neighbor_model_name: str = None):
                                                  -    """
                                                  -        Save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment, color in zip(experiments, COLORS):
                                                  -        metrics = experiment["metrics"]
                                                  -        plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
                                                  -        # print roc_auc for this experiment
                                                  -        print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
                                                  -    plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
                                                  -    plt.xlim([0.0, 1.0])
                                                  -    plt.ylim([0.0, 1.05])
                                                  -    plt.xlabel('False Positive Rate')
                                                  -    plt.ylabel('True Positive Rate')
                                                  -    if neighbor_model_name:
                                                  -        plt.title(f'ROC Curves ({model_name} - {neighbor_model_name})')
                                                  -    else:
                                                  -        plt.title(f'ROC Curves ({model_name})')
                                                  -    plt.legend(loc="lower right", fontsize=6)
                                                  -    plt.savefig(f"{save_folder}/roc_curves.png")
                                                  -
                                                  -    # Also plot ROC curves for low TPR-FPR region
                                                  -    plt.clf()
                                                  -    for experiment, color in zip(experiments, COLORS):
                                                  -        metrics = experiment["metrics"]
                                                  -        plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
                                                  -    plt.xscale('log')
                                                  -    plt.yscale('log')
                                                  -    plt.xlim(1e-5, 1)
                                                  -    plt.ylim(1e-5, 1)
                                                  -    plt.plot([1e-5, 1], [1e-5, 1], color='black', lw=2, linestyle='--')
                                                  -    plt.xlabel('False Positive Rate')
                                                  -    plt.ylabel('True Positive Rate')
                                                  -    if neighbor_model_name:
                                                  -        plt.title(f'ROC Curves ({model_name} - {neighbor_model_name}) : low FPR region')
                                                  -    else:
-        plt.title(f'ROC Curves ({model_name}) : low FPR region')
                                                  -    plt.legend(loc="lower right", fontsize=6)
                                                  -    plt.savefig(f"{save_folder}/roc_curves_low_fpr.png")
                                                  -
                                                  -
                                                  -def save_f1_histogram(f1_scores, save_folder):
                                                  -    """
                                                  -        Function for saving F1-score histograms.
                                                  -    """
                                                  -    plt.hist(f1_scores, bins=20, edgecolor='black', alpha=0.7)
                                                  -    plt.xlabel('F1 Score')
                                                  -    plt.ylabel('Frequency')
                                                  -    plt.title('Histogram of F1 Scores')
                                                  -    plt.savefig(f"{save_folder}/f1_hist.png")
                                                  -    plt.close()
                                                  -
                                                  -
                                                  -def save_ll_histograms(experiments, save_folder):
                                                  -    """
                                                  -        Save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment in experiments:
                                                  -        try:
                                                  -            results = experiment["raw_results"]
                                                  -            # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
                                                  -            plt.figure(figsize=(20, 6))
                                                  -            plt.subplot(1, 2, 1)
                                                  -            plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='member')
                                                  -            plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled')
                                                  -            plt.xlabel("log likelihood")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.subplot(1, 2, 2)
                                                  -            plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='nonmember')
                                                  -            plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original')
                                                  -            plt.xlabel("log likelihood")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.savefig(
                                                  -                f"{save_folder}/ll_histograms_{experiment['name']}.png")
                                                  -        except:
                                                  -            pass
                                                  -
                                                  -
                                                  -def save_llr_histograms(experiments, save_folder):
                                                  -    """
                                                  -        Save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment in experiments:
                                                  -        try:
                                                  -            results = experiment["raw_results"]
                                                  -            # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
                                                  -            plt.figure(figsize=(20, 6))
                                                  -            plt.subplot(1, 2, 1)
                                                  -
                                                  -            # compute the log likelihood ratio for each result
                                                  -            for r in results:
                                                  -                r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
                                                  -                r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
                                                  -            
                                                  -            plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='member')
                                                  -            plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='nonmember')
                                                  -            plt.xlabel("log likelihood ratio")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.savefig(f"{save_folder}/llr_histograms_{experiment['name']}.png")
                                                  -        except:
                                                  -            pass
                                                  -
                                                  @@ -169,151 +39,24 @@

                                                  Functions

                                                  Function for saving F1-score histograms.

                                                  -
                                                  def save_f1_histogram(f1_scores, save_folder):
                                                  -    """
                                                  -        Function for saving F1-score histograms.
                                                  -    """
                                                  -    plt.hist(f1_scores, bins=20, edgecolor='black', alpha=0.7)
                                                  -    plt.xlabel('F1 Score')
                                                  -    plt.ylabel('Frequency')
                                                  -    plt.title('Histogram of F1 Scores')
                                                  -    plt.savefig(f"{save_folder}/f1_hist.png")
                                                  -    plt.close()
                                                  -
                                                  def save_ll_histograms(experiments, save_folder)

                                                  Save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.

                                                  -
                                                  def save_ll_histograms(experiments, save_folder):
                                                  -    """
                                                  -        Save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment in experiments:
                                                  -        try:
                                                  -            results = experiment["raw_results"]
                                                  -            # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
                                                  -            plt.figure(figsize=(20, 6))
                                                  -            plt.subplot(1, 2, 1)
                                                  -            plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='member')
                                                  -            plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled')
                                                  -            plt.xlabel("log likelihood")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.subplot(1, 2, 2)
                                                  -            plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='nonmember')
                                                  -            plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original')
                                                  -            plt.xlabel("log likelihood")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.savefig(
                                                  -                f"{save_folder}/ll_histograms_{experiment['name']}.png")
                                                  -        except:
                                                  -            pass
                                                  -
                                                  def save_llr_histograms(experiments, save_folder)

                                                  Save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.

                                                  -
                                                  def save_llr_histograms(experiments, save_folder):
                                                  -    """
                                                  -        Save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment in experiments:
                                                  -        try:
                                                  -            results = experiment["raw_results"]
                                                  -            # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
                                                  -            plt.figure(figsize=(20, 6))
                                                  -            plt.subplot(1, 2, 1)
                                                  -
                                                  -            # compute the log likelihood ratio for each result
                                                  -            for r in results:
                                                  -                r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
                                                  -                r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
                                                  -            
                                                  -            plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='member')
                                                  -            plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='nonmember')
                                                  -            plt.xlabel("log likelihood ratio")
                                                  -            plt.ylabel('count')
                                                  -            plt.legend(loc='upper right')
                                                  -            plt.savefig(f"{save_folder}/llr_histograms_{experiment['name']}.png")
                                                  -        except:
                                                  -            pass
                                                  -
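Both histogram helpers above read the same keys from each experiment's raw_results. A hand-constructed example input (values are made up; in practice they come from attack outputs) might look like this:

# Example input shape for save_ll_histograms / save_llr_histograms (illustrative values only).
experiments = [{
    "name": "demo",
    "raw_results": [
        {"sampled_ll": -2.1, "perturbed_sampled_ll": -2.6,
         "original_ll": -3.0, "perturbed_original_ll": -3.1},
        {"sampled_ll": -1.8, "perturbed_sampled_ll": -2.4,
         "original_ll": -2.9, "perturbed_original_ll": -3.3},
    ],
}]
# save_ll_histograms(experiments, "/tmp/plots")    # hypothetical output folder
# save_llr_histograms(experiments, "/tmp/plots")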
                                                  def save_roc_curves(experiments, save_folder, model_name, neighbor_model_name: str = None)

                                                  Save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors.

                                                  -
                                                  def save_roc_curves(experiments, save_folder, model_name, neighbor_model_name: str = None):
                                                  -    """
                                                  -        Save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors.
                                                  -    """
                                                  -    # first, clear plt
                                                  -    plt.clf()
                                                  -
                                                  -    for experiment, color in zip(experiments, COLORS):
                                                  -        metrics = experiment["metrics"]
                                                  -        plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
                                                  -        # print roc_auc for this experiment
                                                  -        print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
                                                  -    plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
                                                  -    plt.xlim([0.0, 1.0])
                                                  -    plt.ylim([0.0, 1.05])
                                                  -    plt.xlabel('False Positive Rate')
                                                  -    plt.ylabel('True Positive Rate')
                                                  -    if neighbor_model_name:
                                                  -        plt.title(f'ROC Curves ({model_name} - {neighbor_model_name})')
                                                  -    else:
                                                  -        plt.title(f'ROC Curves ({model_name})')
                                                  -    plt.legend(loc="lower right", fontsize=6)
                                                  -    plt.savefig(f"{save_folder}/roc_curves.png")
                                                  -
                                                  -    # Also plot ROC curves for low TPR-FPR region
                                                  -    plt.clf()
                                                  -    for experiment, color in zip(experiments, COLORS):
                                                  -        metrics = experiment["metrics"]
                                                  -        plt.plot(metrics["fpr"], metrics["tpr"], label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}", color=color)
                                                  -    plt.xscale('log')
                                                  -    plt.yscale('log')
                                                  -    plt.xlim(1e-5, 1)
                                                  -    plt.ylim(1e-5, 1)
                                                  -    plt.plot([1e-5, 1], [1e-5, 1], color='black', lw=2, linestyle='--')
                                                  -    plt.xlabel('False Positive Rate')
                                                  -    plt.ylabel('True Positive Rate')
                                                  -    if neighbor_model_name:
                                                  -        plt.title(f'ROC Curves ({model_name} - {neighbor_model_name}) : low FPR region')
                                                  -    else:
-        plt.title(f'ROC Curves ({model_name}) : low FPR region')
                                                  -    plt.legend(loc="lower right", fontsize=6)
                                                  -    plt.savefig(f"{save_folder}/roc_curves_low_fpr.png")
                                                  -
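save_roc_curves expects one dictionary per experiment with precomputed ROC metrics (typically obtained from something like sklearn.metrics.roc_curve and roc_auc_score on attack scores). A hand-constructed example (values are made up; folder and model name are placeholders):

# Example input shape for save_roc_curves (illustrative values only).
experiments = [{
    "name": "loss-attack",
    "metrics": {
        "fpr": [0.0, 0.1, 0.5, 1.0],
        "tpr": [0.0, 0.4, 0.8, 1.0],
        "roc_auc": 0.72,
    },
}]
# save_roc_curves(experiments, save_folder="/tmp/plots", model_name="gpt2")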
                                                  @@ -326,7 +69,6 @@

                                                  Functions

                                                  MIMIR -

                                                  Index

                                                    @@ -348,7 +90,7 @@

                                                    Index

                                                    - \ No newline at end of file + diff --git a/docs/utils.html b/docs/utils.html index ed5bb72..fe6876c 100644 --- a/docs/utils.html +++ b/docs/utils.html @@ -2,19 +2,22 @@ - - + + mimir.utils API documentation - - - - - +This module provides utility functions …"> + + + + + - - + +
                                                    @@ -28,68 +31,6 @@

                                                    Module mimir.utils

Environment variables: MIMIR_CACHE_PATH (the path to the cache directory) and MIMIR_DATA_SOURCE (the data source for the MIMIR project); both should be set in the environment.

                                                    -
                                                    """
                                                    -utils.py
                                                    -This module provides utility functions.
                                                    -
                                                    -Environment Variables:
                                                    -    MIMIR_CACHE_PATH: The path to the cache directory. This should be set in the environment.
                                                    -    MIMIR_DATA_SOURCE: The data source for the MIMIR project. This should be set in the environment.
                                                    -"""
                                                    -
                                                    -import os
                                                    -import random
                                                    -import torch as ch
                                                    -import numpy as np
                                                    -
                                                    -# Read environment variables
                                                    -CACHE_PATH = os.environ.get('MIMIR_CACHE_PATH', None)
                                                    -DATA_SOURCE = os.environ.get('MIMIR_DATA_SOURCE', None)
                                                    -
                                                    -
                                                    -def fix_seed(seed: int = 0):
                                                    -    """
                                                    -    Fix seed for reproducibility.
                                                    -
                                                    -    Parameters:
                                                    -        seed (int): The seed to set. Default is 0.
                                                    -    """
                                                    -    ch.manual_seed(seed)
                                                    -    np.random.seed(seed)
                                                    -    random.seed(seed)
                                                    -
                                                    -
                                                    -def get_cache_path():
                                                    -    """
                                                    -    Get path to cache directory.
                                                    -    Returns:
                                                    -        str: path to cache directory
                                                    -
                                                    -    Raises:
                                                    -        ValueError: If the MIMIR_CACHE_PATH environment variable is not set.
                                                    -    """
                                                    -    if CACHE_PATH is None:
                                                    -        raise ValueError('MIMIR_CACHE_PATH environment variable not set')
                                                    -    return CACHE_PATH
                                                    -
                                                    -
                                                    -def get_data_source():
                                                    -    """
                                                    -    Get path to data source directory.
                                                    -    Returns:
                                                    -        str: path to data source directory
                                                    -
                                                    -    Raises:
                                                    -        ValueError: If the MIMIR_DATA_SOURCE environment variable is not set.
                                                    -    """
                                                    -    if DATA_SOURCE is None:
                                                    -        raise ValueError('MIMIR_DATA_SOURCE environment variable not set')
                                                    -    return DATA_SOURCE
                                                    -
                                                    @@ -105,21 +46,6 @@

                                                    Functions

                                                    Fix seed for reproducibility.

                                                    Parameters

                                                    seed (int): The seed to set. Default is 0.

                                                    -
                                                    def fix_seed(seed: int = 0):
                                                    -    """
                                                    -    Fix seed for reproducibility.
                                                    -
                                                    -    Parameters:
                                                    -        seed (int): The seed to set. Default is 0.
                                                    -    """
                                                    -    ch.manual_seed(seed)
                                                    -    np.random.seed(seed)
                                                    -    random.seed(seed)
                                                    -
                                                    def get_cache_path() @@ -136,23 +62,6 @@

                                                    Raises

                                                    ValueError
                                                    If the MIMIR_CACHE_PATH environment variable is not set.
                                                    -
                                                    def get_cache_path():
                                                    -    """
                                                    -    Get path to cache directory.
                                                    -    Returns:
                                                    -        str: path to cache directory
                                                    -
                                                    -    Raises:
                                                    -        ValueError: If the MIMIR_CACHE_PATH environment variable is not set.
                                                    -    """
                                                    -    if CACHE_PATH is None:
                                                    -        raise ValueError('MIMIR_CACHE_PATH environment variable not set')
                                                    -    return CACHE_PATH
                                                    -
                                                    def get_data_source() @@ -169,23 +78,6 @@

                                                    Raises

                                                    ValueError
                                                    If the MIMIR_DATA_SOURCE environment variable is not set.
                                                    -
                                                    def get_data_source():
                                                    -    """
                                                    -    Get path to data source directory.
                                                    -    Returns:
                                                    -        str: path to data source directory
                                                    -
                                                    -    Raises:
                                                    -        ValueError: If the MIMIR_DATA_SOURCE environment variable is not set.
                                                    -    """
                                                    -    if DATA_SOURCE is None:
                                                    -        raise ValueError('MIMIR_DATA_SOURCE environment variable not set')
                                                    -    return DATA_SOURCE
                                                    -
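Because CACHE_PATH and DATA_SOURCE are read at import time, the environment variables must be set before mimir.utils is imported. A short usage sketch (paths are placeholders):

# Set the environment variables before importing mimir.utils.
import os
os.environ["MIMIR_CACHE_PATH"] = "/tmp/mimir_cache"   # placeholder path
os.environ["MIMIR_DATA_SOURCE"] = "/tmp/mimir_data"   # placeholder path

from mimir.utils import get_cache_path, get_data_source

print(get_cache_path())    # /tmp/mimir_cache
print(get_data_source())   # /tmp/mimir_data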
                                                    @@ -198,7 +90,6 @@

                                                    Raises

                                                    MIMIR -

                                                    Index

                                                      @@ -219,7 +110,7 @@

                                                      Index

                                                      - \ No newline at end of file +