Detecting, Explaining, and Mitigating Memorization in Diffusion Models (+ geometric functionalities)
This is a fork of the original repository for Detecting, Explaining, and Mitigating Memorization in Diffusion Models. We have added local intrinsic dimension (LID) functionality for mitigating memorization. Specifically, we take the LID estimator FLIPD from the paper A Geometric View of Data Complexity, which is differentiable by design, and use it as a replacement for the classifier-free guidance (CFG) norm proposed in the original paper. We adapt FLIPD to denoising diffusion implicit models (DDIMs), and this adaptation was used to run the experiments in the paper. We also include a pure score-norm metric and compare it against the other two methods.
In addition, we include a GPT-based method for prompt perturbation that uses the token attributions obtained from these methods. For the base functionalities, please consult the original repository.
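For intuition, here is a minimal sketch of a FLIPD-style LID estimate. This is our paraphrase with hypothetical names (`score_fn`, `flipd_estimate`); the exact time/noise scaling depends on the SDE parameterization, so consult the code for the DDIM adaptation actually used:

```python
import torch

def flipd_estimate(score_fn, x, t0, sigma2, n_probes=4):
    """Sketch of FLIPD: LID(x) ~ D + sigma^2 * (tr(grad_x s) + ||s||^2) at small t0."""
    x = x.detach().requires_grad_(True)
    s = score_fn(x, t0)  # score s_theta(x, t0), same shape as x
    sq_norm = (s ** 2).sum()
    # Hutchinson estimator: E_v[v^T (grad_x s) v] approximates tr(grad_x s).
    trace = 0.0
    for _ in range(n_probes):
        v = torch.randn_like(x)
        (vjp,) = torch.autograd.grad(s, x, grad_outputs=v,
                                     retain_graph=True, create_graph=True)
        trace = trace + (v * vjp).sum() / n_probes
    D = x.numel()
    # create_graph=True keeps the estimate differentiable, so it can itself
    # be optimized, e.g. with respect to prompt embeddings.
    return D + sigma2 * (trace + sq_norm)
```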
Remark 1: All the command-line scripts here expose their configuration and hyperparameters through Hydra; please check `configs/` for more details.
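For example, any key defined in `configs/` can be overridden directly on the command line:

```bash
# Hydra-style override (the attribution_method key is used by the scripts below)
python store_attributions.py attribution_method=flipd
```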
Remark 2: All logs in this project are stored in the `outputs` directory. Some default logs are also included in the repo so you can get to know their structure.
Run the following script to set up the environment:
```bash
conda env create -f env.yml
conda activate geometric_diffusion_memorization
```
The following script takes the prompts stored in `match_verbatim_captions.json` and runs attribution methods on them. An attribution method takes a caption and measures how influential each token is for memorization. We have three attribution methods: `cfg_norm`, `score_norm`, and `flipd`. The first is the metric proposed in the original paper, the second is a simple score norm, and the third is the LID-based method.
```bash
# get the token attributions for a set of memorized prompts
python store_attributions.py attribution_method=<score_norm|cfg_norm|flipd>
```
This will store a JSON file in `outputs/attributions` that contains all the token attributions obtained via any of the metrics above. Note that you can also specify your own set of prompts.
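For intuition, a gradient-based token attribution can be sketched as follows. This is a minimal sketch with hypothetical names, not the repo's actual API; `metric_fn` stands for any differentiable scalar memorization metric (CFG norm, score norm, or FLIPD) evaluated from the text embeddings:

```python
import torch

def token_attributions(metric_fn, text_emb):
    """One influence score per token: gradient magnitude of the metric."""
    text_emb = text_emb.detach().requires_grad_(True)  # (seq_len, emb_dim)
    loss = metric_fn(text_emb)  # scalar memorization metric
    loss.backward()
    return text_emb.grad.norm(dim=-1)  # (seq_len,)
```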
After obtaining these attributions, we can use them to perturb the prompts. For that, we use the OpenAI API to generate new, semantically relevant prompts that avoid the sensitive tokens causing memorization.
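Conceptually, the perturbation step looks like the sketch below, which uses the official `openai` Python client. The model name and instruction wording are assumptions for illustration, not necessarily what `perturb_gpt.py` does verbatim; it also assumes `OPENAI_API_KEY` is set up as described next:

```python
from openai import OpenAI

def perturb_prompt(prompt, sensitive_tokens, model="gpt-4o-mini"):
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    instruction = (
        "Rewrite the following image-generation prompt so that it keeps the "
        f"same meaning but avoids these tokens: {', '.join(sensitive_tokens)}.\n\n"
        f"Prompt: {prompt}"
    )
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": instruction}],
    )
    return response.choices[0].message.content
```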
Prompt perturbation is done through our `perturb_gpt.py` script. Before running it, you have to store your `OPENAI_API_KEY` in a `.env` file in the root directory of this repo; that is, you should have a file `.env` with the following content:

```
OPENAI_API_KEY='<your-key>'
```
You can then run the following script:
```bash
python perturb_gpt.py attribution_method=<score_norm|cfg_norm|flipd|random>
```
This script reads the attributions already stored in `outputs/attributions`, so make sure to run the previous step first; we have, however, included some default attributions in the repo. Note that there is also a `random` method that perturbs the prompts randomly. Finally, after running this script, all the perturbed prompts are stored in a separate file in `outputs/perturbed_prompts`.
This method is the analogue of the main inference-time mitigation method described in Detecting, Explaining, and Mitigating Memorization in Diffusion Models: during the generation process, the prompt embedding is optimized to minimize a memorization metric, which in that paper was the CFG norm. We additionally support the LID-based and score-norm metrics (see the sketch at the end of this section). You can run the following script to optimize the prompts:
```bash
python inference_time_mitigation_optimization.py mitigation_method=<score_norm|cfg_norm|flipd>
```
Similar to the perturbation script, this script also reads the prompts stored in the `match_verbatim_captions.json` file. After running it, the optimized prompts throughout the optimization history are passed through the generator network, and the generations are stored in `outputs/inference_time_mitigation`. There are examples already stored in the repo.
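For intuition, the embedding optimization at the heart of this mitigation can be sketched as follows. This is a minimal sketch with hypothetical names; `metric_fn` stands for the differentiable `cfg_norm`, `score_norm`, or `flipd` metric computed during generation:

```python
import torch

def mitigate_prompt_embedding(metric_fn, prompt_emb, steps=10, lr=0.05):
    """Nudge the prompt embedding to lower the memorization metric before generating."""
    emb = prompt_emb.detach().clone().requires_grad_(True)
    opt = torch.optim.Adam([emb], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = metric_fn(emb)  # memorization metric for the current embedding
        loss.backward()
        opt.step()
    return emb.detach()
```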