[WIP] Add Adversarial Diffusion Distillation (ADD) Script #6303

Open · wants to merge 58 commits into base: main
Conversation

@dg845 (Contributor) commented Dec 23, 2023

What does this PR do?

This PR adds an example script for adversarial diffusion distillation (ADD) (paper, code), a distillation + adversarial training method used to distill SD/SD-XL Turbo.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten
@sayakpaul
@patil-suraj

@dg845 dg845 changed the title Add Adversarial Diffusion Distillation (ADD) Script [WIP] Add Adversarial Diffusion Distillation (ADD) Script Dec 24, 2023
def transform(example):
    # resize image
    image = example["image"]
    image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
Contributor:

Bilinear causes image artifacts that impact training quality in major ways. Use LANCZOS.
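A minimal sketch of the suggested change, assuming image is a PIL image (torchvision's InterpolationMode.LANCZOS is only supported for PIL inputs, not tensors):

from torchvision import transforms
import torchvision.transforms.functional as TF

# Resize with LANCZOS resampling instead of BILINEAR to reduce resize artifacts.
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.LANCZOS)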

Member:

Some evidence would be nice.

Contributor:

We could make this configurable as well if users prefer to use different interpolations.

Contributor:

original img: too big to preview

bilinear: [image]

lanczos: [image]

if self.prediction_type == "epsilon":
    pred_x_0 = (sample - sigmas * model_output) / alphas
elif self.prediction_type == "sample":
    pred_x_0 = model_output
Contributor:

The U-Net returns the residual noise prediction, right? Or is it not an intermediary phase with XL Turbo?

Contributor:

The UNet returns the residual noise prediction only when prediction_type="epsilon", which is the default for most SD models. This is to support different prediction_types.

Contributor:

Well, when we train on v_prediction, it uses the residual returned from the UNet as an input to get_velocity; ergo, it is an intermediary stage for v-prediction in Diffusers training. But this code makes it appear as if the sample is returned directly.

Contributor:

I think this is fine, as the denoiser does exactly that: it returns the predicted original sample.
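For reference, a hedged sketch of the corresponding v_prediction branch, using the standard v-parameterization conversion and assuming alphas and sigmas are defined as in the snippet above:

elif self.prediction_type == "v_prediction":
    # v = alpha * noise - sigma * x_0, and sample = alpha * x_0 + sigma * noise,
    # so x_0 = alpha * sample - sigma * v.
    pred_x_0 = alphas * sample - sigmas * model_output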



@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
Contributor:

Why not use the EMAModel class?

Member:

Yeah, +1 to that. Let's try using the EMAModel class.

Contributor @patil-suraj (Dec 27, 2023):

Not fully sure, but using EMA like this (a fixed EMA rate) might make sense for distillation, as we might not want to change the model too much. So a fixed, high-enough value of rate could be better.

Contributor (Author):

Just to confirm, in EMAModel there's currently no option to set a fixed decay rate? In get_decay it looks like self.decay is not used whether self.use_ema_warmup is True or False:

if self.use_ema_warmup:
    cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
else:
    cur_decay_value = (1 + step) / (10 + step)

In fact, it doesn't seem like self.decay is used in any of the EMA logic at all.

Member:

Let me open a PR for that.

Member:

@dg845 please correct me if I am wrong here.

This is where self.decay is used in get_decay():

cur_decay_value = min(cur_decay_value, self.decay)

Shouldn't that suffice?

Contributor (Author):

Thanks, I missed that. I think a fixed decay rate can be achieved by setting decay == min_decay.
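A minimal sketch of that, assuming get_decay() clamps the decay into [min_decay, decay] as quoted above:

from diffusers.training_utils import EMAModel

# With decay == min_decay, get_decay() always returns the same value,
# giving a fixed EMA rate of 0.99.
ema_unet = EMAModel(unet.parameters(), decay=0.99, min_decay=0.99)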

action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
Contributor:

There's no need to unconditionally resize before cropping, especially on a diverse dataset; not resizing first allows finer details to be learned.

Comment on lines 1111 to 1112
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
# TODO: is there a better way to implement this?
Contributor:

There doesn't seem to be much point to this, since even @PeterL1n et al. showed that zero-terminal SNR doesn't do anything meaningful for epsilon models, and SDXL Turbo doesn't use v-prediction, unfortunately.

Contributor:

I think, as the losses are computed in pixel space, it could still have some effect for epsilon prediction.


if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
Contributor:

!= controlnet (this error message was copied over from the ControlNet script; it should refer to the UNet)


# 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
Contributor:

Inference differs significantly between bfloat16 and float16.

student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device)

# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
Contributor:

We're on accelerate > 0.20.0 now.

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
Contributor:

Other than matmul, you can also enable:

torch.backends.cudnn.allow_tf32 = True
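Combined, the two flags would look like this (a sketch reusing the script's existing args.allow_tf32 guard):

if args.allow_tf32:
    # Allow TF32 in both matmul and cuDNN convolution kernels on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True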

image, text = batch

image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
Contributor:

Precomputing the embeds allows for lower VRAM use.

Caption dropout should be implemented too.
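For illustration, a minimal caption-dropout sketch (caption_dropout_prob is a hypothetical argument, not one the script currently defines):

import random

def maybe_drop_caption(text, caption_dropout_prob=0.1):
    # With probability caption_dropout_prob, train on the empty caption so the
    # model also learns an unconditional score.
    return "" if random.random() < caption_dropout_prob else text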

Contributor:

Think caption dropout is not necessary for this, as here we are distilling the CFG score of the teacher model.

Member:

I think it's also fine to NOT precompute the text embeddings for now as we're aiming for a bigger training run here. We can revisit this later.

# encode pixel values with batch size of at most 32
latents = []
for i in range(0, pixel_values.shape[0], 32):
    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
Contributor:

It would be better not to encode latents on the fly, as that substantially increases VRAM use.

An option to re-cache the VAE latents every epoch would be nice, since then the random crop and random flip remain functional.
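A rough sketch of the re-caching idea (the dataloader yielding (image, text) pairs and the in-memory storage are assumptions):

import torch

@torch.no_grad()
def cache_latents(vae, dataloader):
    # Pre-encode all images once; redoing this each epoch keeps the random
    # crop and flip augmentations fresh.
    cached = []
    for image, text in dataloader:
        latents = vae.encode(image.to(vae.device, dtype=vae.dtype)).latent_dist.sample()
        cached.append((latents.cpu(), text))
    return cached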

Member:

Absolutely okay to not consider this for now:
https://github.com/huggingface/diffusers/pull/6303/files#r1436066871

image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

# get crop coordinates and crop image
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
Contributor:

What about args.center_crop? Or a crop that preserves aspect bucketing? Legacy SD training greatly benefits from data bucketing.

image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

# get crop coordinates and crop image
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
Contributor:

So this unconditionally crops the example even though the crop coords are used conditionally based on the value of use_fix_crop_and_size.

Additionally, it always uses RandomCrop, without args.center_crop being taken into account.

Further, it crops to a square every time.

This isn't necessary; you can preserve the aspect ratio.

else:
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

validation_prompts = [
Contributor:

You can use a prompt dataset here.

Member:

For a reference example script, that ain't necessary.

Contributor:

My suggestion was based on the other training example scripts :-)

Contributor @patil-suraj left a comment:

Great start! The script already covers most of the details for ADD. Left some comments. Will try to give it a run in the next few days.


Comment on lines +1530 to +1532
# 1. Decode real and fake (generated) latents back to pixel space.
# NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the
# pretrained feature network for the discriminator operates in pixel space rather than latent space.
Contributor:

That's correct!

student_gen_image = vae.decode(student_x_0).sample

# 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively.
disc_output_real = discriminator(real_image, prompt_embeds)
Contributor:

What kind of image input does the DINO model expect? Since it's normalized with ImageNet mean and std, we should convert the decoded images to the 0-1 range, like:

real_image = (real_image / 2 + 0.5).clamp(0, 1)
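A hedged sketch of the full preprocessing this implies, using the standard ImageNet mean/std (whether FeatureNetwork already applies the normalization is the open question below):

import torchvision.transforms.functional as TF

real_image = (real_image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
real_image = TF.normalize(real_image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])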

Contributor:

And I think we should resize the images here, as DINO expects 224x224 images IIRC.

It's done in the FeatureNetwork class.

Contributor:

Did we check the range of expected inputs?

lr_scheduler.step()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
Contributor:

We should do the EMA update here.
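That is, something like the following sketch, assuming an EMAModel instance named ema_unet:

if accelerator.sync_gradients:
    # Update EMA weights only after a real optimizer step, i.e. once per
    # full gradient-accumulation cycle.
    ema_unet.step(unet.parameters())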

optimizer.zero_grad(set_to_none=True)

# 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator
disc_output_fake = discriminator(student_gen_image, prompt_embeds)
Contributor:

This term is already computed above. Do we need to recompute it here? Not sure, since we are not doing vanilla GAN training here, so we might be able to reuse it.

Comment on lines 596 to 601
validation_prompts = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
Member:

Prompts from the ADD paper could also be used.

Contributor (Author):

Changed prompts to those used in the ADD paper (note that all example images from the paper are generated by ADD-XL).


# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
# TODO: is there a better way to implement this?
Member @sayakpaul (Dec 25, 2023):

Your current implementation looks good to me!

Contributor (Author):

Note that I'm currently using DDIMScheduler, which currently supports rescale_betas_zero_snr, but ideally I'd like to use DDPMScheduler here, since my understanding is that DDPMScheduler can typically load DDIMScheduler configs but not vice versa. Since DDPMScheduler currently does not support rescale_betas_zero_snr, I've opened a PR to add it: #6305.
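Once #6305 lands, the scheduler setup could look like this sketch (argument names follow the script's existing flags):

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler.from_pretrained(
    args.pretrained_teacher_model,
    subfolder="scheduler",
    rescale_betas_zero_snr=True,  # enforce zero terminal SNR (ADD paper, section 3.1)
)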

Contributor:

You can simply use Euler now instead of DDIM, because it also supports zero-terminal SNR. This would match the behaviour of the ControlNet trainer, which uses Euler.

Comment on lines +1471 to +1473
pixel_values = image.to(dtype=weight_dtype)
if vae.dtype != weight_dtype:
    vae.to(dtype=weight_dtype)
Member:

As hinted by @patil-suraj above, we can safely always have the VAE in a reduced precision in case of SD.

Contributor:

Not for SD 2.1, which ends up with NaNs inside the VAE at half precision on finetuned models. Never figured that one out.

Member:

Oh, for SD2.1, that is the case?

Contributor:

Yes, it's noticeably an issue when finetuning a 2.1-v model in Diffusers and then trying to do inference later on an AMD system, which lacks xformers etc.: the float32 VAE OOMs, but float16 produces NaNs.

Member @sayakpaul left a comment:

Amazing work. Excited to see how the results come up!

I'd be keen on experimenting with simpler discriminators though. But that is obviously not a blocker.

@Metal079 commented Jan 8, 2024

Excuse my ignorance, but how do I run it? I seem to get this error trying to run either the LoRA or SD 1.5 version:

accelerate launch train_add_distill_lora_sd_wds.py --pretrained_teacher_model C:\Users\Pablo\Documents\mobians_api\sonicDiffusionV4 --train_shards_path_or_url laion/conceptual-captions-12m-webdataset

Traceback (most recent call last):
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 2114, in <module>
    main(args)
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 1693, in main
    dataset = SDText2ImageDataset(
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 238, in __init__
    num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'

@dg845 (Contributor, Author) commented Jan 10, 2024

I think there are a few additional arguments that need to be explicitly supplied for the scripts to not raise an error. Something close to the minimal set of arguments needed is

accelerate launch examples/add/train_add_distill_lora_sd_wds.py \
    --pretrained_teacher_model="<teacher_model>" \
    --train_shards_path_or_url="<dataset>" \
    --output_dir="<output_dir>" \
    --max_train_steps=1 \
    --max_train_samples=20 \
    --dataloader_num_workers=8

assuming the other default values work (for example, --train_batch_size is not so big that it leads to an OOM error).

Note that the scripts are a work in progress and there's no guarantee that they work currently.

@Metal079 commented Jan 10, 2024

Got it running, but ran into a bug when saving. Validation images also looked like random noise.

Steps:  40%|▍| 400/1000 [35:56<53:16,  5.33s/it, d_total_loss=1.46, g_adv_loss=-.086, g_distill_loss=0.123, g_total_loss01/10/2024 16:24:08 - INFO - __main__ - Running validation...
                                                                                                                       Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of ../sonicDiffusionV4.               | 0/5 [00:00<?, ?it/s]
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of ../sonicDiffusionV4.
                                                                                                                       Loaded scheduler as PNDMScheduler from `scheduler` subfolder of ../sonicDiffusionV4.       | 3/5 [00:00<00:00,  5.95it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:00<00:00,  9.88it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Steps:  50%|▌| 500/1000 [44:53<43:40,  5.24s/it, d_total_loss=1.51, g_adv_loss=-.227, g_distill_loss=0.00964, g_total_lo01/10/2024 16:33:06 - INFO - accelerate.accelerator - Saving current state to add_model/checkpoint-500
Configuration saved in add_model/checkpoint-500/unet/config.json
Model weights saved in add_model/checkpoint-500/unet/diffusion_pytorch_model.safetensors
Traceback (most recent call last):
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 2010, in <module>
    main(args)
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1957, in main
    accelerator.save_state(save_path)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2706, in save_state
    hook(self._models, weights, output_dir)
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1500, in save_model_hook
    model.save_pretrained(os.path.join(output_dir, "unet"))
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Discriminator' object has no attribute 'save_pretrained'
wandb: \ 17.403 MB of 17.417 MB uploaded
wandb: Run history:
wandb:  d_adv_loss_fake ▆▆▃▄▂▃▄▅▄▄▇▃▅▃█▇▃▆▂▃█▅▄▂▂▂▅▃▂▃▄▇▆▄▃▃▆▁▁▄
wandb:  d_adv_loss_real ▆▆▆▄▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb:      d_loss_real ▆▆▆▅▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb: d_r1_regularizer ███▆▇▄▆▄▇▆▄▅▃▆▁▁▂▄▆▂▃▄▄▂▄▃▃▅▃▂▆▁▄▃▅▂▁▄▇▄
wandb:     d_total_loss ██▅▅▅▄▅▅▆▅▇▄▅▅▇▆▃▇▅▃█▅▅▁▂▂▅▃▂▃▅▅▆▄▄▂▅▁▄▅
wandb:       g_adv_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▂▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb:   g_distill_loss ▄▃▃▂▃▁▅▃▅▁▂▃▃▄▁▁▂▃▂▂▂▁▂▂▁▁▅▃▅▂▃▁▂▂▁▂█▂▁▃
wandb:     g_total_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▃▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb:               lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:
wandb: Run summary:
wandb:  d_adv_loss_fake 1.29867
wandb:  d_adv_loss_real 0.19679
wandb:      d_loss_real 0.20999
wandb: d_r1_regularizer 1319.4856
wandb:     d_total_loss 1.50865
wandb:       g_adv_loss -0.22699
wandb:   g_distill_loss 0.00964
wandb:     g_total_loss -0.20288
wandb:               lr 0.0001
wandb:
wandb: 🚀 View run unique-armadillo-16 at: https://wandb.ai/metal/text2image-fine-tune/runs/0tks0q98
wandb: Synced 5 W&B file(s), 32 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240110_154811-0tks0q98/logs
Traceback (most recent call last):
  File "/home/metal/dpo_test/venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/metal/dpo_test/venv/bin/python3', 'train_add_distill_sd_wds.py', '--pretrained_teacher_model=../sonicDiffusionV4', '--train_shards_path_or_url=pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true', '--output_dir=add_model', '--max_train_steps=1000', '--max_train_samples=4000000', '--dataloader_num_workers=8', '--train_batch_size=2', '--allow_tf32', '--mixed_precision=fp16', '--report_to=wandb', '--gradient_checkpointing', '--use_8bit_adam', '--gradient_accumulation_steps=8', '--allow_nonzero_terminal_snr']' returned non-zero exit status 1.


d_r1_regularizer = 0
for k, head in discriminator.heads.items():
    head_grad_params = torch.autograd.grad(
        outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True
Contributor:

According to the paper, the R1 penalty seems to be computed w.r.t. the head's input instead of the head's parameters?

Contributor (Author) @dg845 (Jan 14, 2024):

Thanks for the catch, I think you're right. In fact there seem to be several problems in the current implementation:

  • As you pointed out, the gradient penalty should be calculated with respect to the discriminator head inputs. That's not available in the current code, but if the input to discriminator head k is available as features[k], I think the fix would be to set the inputs argument to torch.autograd.grad to features[k] (if I understand autograd correctly).
  • It looks like I misunderstood the definition of the R1 gradient penalty. The ADD paper cites this paper when discussing the R1 gradient penalty, and the latter paper defines the R1 gradient penalty as
$$R_1(\phi) = \frac{1}{2}\mathbb{E}_{x_0 \sim p_{\text{data}}(x_0)}\left[\|\nabla D_{\phi, k}(F_k(x_0))\|^2\right]$$

So it seems like we should be using the L2 norm rather than the L1 norm when calculating the gradient penalty. It's also possible that the implementation is off by a factor of $\frac{1}{2}$; I'm not sure if they absorb the factor of $\frac{1}{2}$ into $\gamma$ (and their reported value of $\gamma = 10^{-5}$).

@patil-suraj @sayakpaul does this sound correct to you guys?

Contributor:

Yeah, regarding the formula it is the L2 norm, something like this:

d_r1_regularizer = sum((torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1) ** 2).mean() for grad in feature_grads)

Contributor (Author) @dg845 (Jan 15, 2024):

I have added a tentative fix for the discriminator R1 gradient penalty for the SD ADD script in commit ab46142. In particular this part

for k, feature in features_real.items():
    # Required so that the torch.autograd.grad call below works properly?
    feature.requires_grad_(True)

feels weird to me, but seems necessary because the features in features_real don't usually have gradients, since the feature_network is frozen. @erliding @patil-suraj @sayakpaul it would be great if you could look this over.

Contributor:

Yep, you need to explicitly call feature.requires_grad_() on the features before passing them to the heads.
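Putting this thread together, a hedged sketch of the R1 penalty computed w.r.t. the head inputs (the head call signature and the hinge loss form are assumptions based on the snippets quoted above):

import torch

# Enable grads on the frozen feature network's outputs so autograd can
# differentiate the heads w.r.t. their inputs (not their parameters).
for k, feature in features_real.items():
    feature.requires_grad_(True)

disc_outputs_real = {k: head(features_real[k]) for k, head in discriminator.heads.items()}
d_adv_loss_real = sum(torch.relu(1.0 - out).mean() for out in disc_outputs_real.values())

# R1 penalty: squared L2 norm of the gradient w.r.t. the head inputs.
feature_grads = torch.autograd.grad(
    outputs=d_adv_loss_real, inputs=list(features_real.values()), create_graph=True
)
d_r1_regularizer = sum(
    (torch.linalg.vector_norm(g.view(g.size(0), -1), dim=1) ** 2).mean() for g in feature_grads
)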

Contributor (Author):

@sayakpaul the authors report that it is pretty useful:

[image: add_paper_disc_grad_penalty]

(from Section 3.2 of the ADD paper)

Contributor:

@sayakpaul As far as I understand, the gradient penalty could help enforce Lipschitz continuity (thus bounding the gradient of the output w.r.t. the input) on the discriminator, which is a requirement of the Wasserstein GAN.

Member:

Right, but we aren't using the earth mover's distance in the loss here, no?

Contributor:

Yeah, it's the hinge loss here :) The gradient penalty should bring similar benefits though.

Contributor:

@dg845 It does look good to me. We need to enable grads for the feature inputs to be able to compute the grad penalty; if the input does not have grad enabled, I think create_graph=True will complain.

For reference, DDGAN training has this.


if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
discriminator_optimizer.step()
Contributor:

If I understand correctly, when gradient_accumulation_steps > 1 the current implementation seems to result in simultaneous gradient descent up to the last accumulation step for each batch, while when gradient_accumulation_steps == 1 it is alternating gradient descent. In StyleGAN-T it is always alternating gradient descent.

Member:

How important is that to follow? Is it absolutely a must for training stability?

Contributor:

The convergence behavior of simultaneous and alternating gradient descent differs when approaching the min-max equilibrium. Judging from recent GAN implementations, alternating gradient descent is usually adopted; not sure how much it helps in the case of ADD though.

if args.use_image_conditioning:
    image_embedding = encoded_image.pop("image_embeds").float()
    # Only supply image conditioning when student timestep is not last training timestep T.
    image_embedding = torch.where(
Contributor:

Currently the same masked image_embedding is fed to the discriminator for both real and fake images, but it seems that, for a real image, its image_embedding could be masked out at random with a rate of, say, 1 / num_inference_steps, instead of depending on student_timesteps. Not sure if this could make a big difference though.
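A rough sketch of that alternative (num_inference_steps and the tensor shapes are assumptions):

# Mask image conditioning for real images at a fixed random rate instead of
# tying the mask to student_timesteps.
drop_mask = torch.rand(bsz, device=latents.device) < (1.0 / num_inference_steps)
real_image_embedding = torch.where(
    drop_mask[:, None], torch.zeros_like(image_embedding), image_embedding
)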

Member:

Yeah, same. I am not sure either how impactful this would be.

student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long()
student_timesteps = student_timestep_schedule[student_index]
teacher_timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
Contributor:

It might be better not to sample teacher_timesteps from the full range of [0, noise_scheduler.config.num_train_timesteps), but instead to ignore timesteps that are too small or too big; e.g., the default configurable range in DreamFusion is [0.02, 0.98] * num_train_timesteps.
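For instance, a minimal sketch of the truncated sampling (the 0.02/0.98 bounds are the DreamFusion defaults mentioned above):

# Sample teacher timesteps away from the extremes of the noise schedule.
T = noise_scheduler.config.num_train_timesteps
t_min, t_max = int(0.02 * T), int(0.98 * T)
teacher_timesteps = torch.randint(t_min, t_max, (bsz,), device=latents.device)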

Contributor (Author):

I guess Figure 2 of the ADD paper implies that they sample from the full range of teacher timesteps:

[image: add_figure_2_teacher_timesteps]

But this is definitely something we can try out :).

@SteamedGit commented:

How far away is this PR from being merged?

@dg845 (Contributor, Author) commented Feb 5, 2024

Hi @SteamedGit, the ADD implementation is nominally complete, but I have not been able to test whether the script can distill good models (e.g. for SD v1.5) yet.

dg845 added 4 commits February 5, 2024 23:26
…itive value instead of zero following EulerDiscreteScheduler.
…y whether we use a CLIPTextModel or CLIPTextModelWithProjection (e.g. with --use_pretrained_projection).
github-actions bot commented Mar 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 5, 2024
@sayakpaul (Member):

Not stale.

@github-actions github-actions bot removed the stale Issues that haven't received updates label Mar 6, 2024
@cjt222 commented Mar 13, 2024

@sayakpaul @dg845 Great job! Can someone please confirm if the effectiveness of this PR has been verified?

@erliding (Contributor):

Regarding computing the SDS loss, I suggest taking a look at https://arxiv.org/abs/2306.04619, which tends to produce a better target.

@dg845 (Contributor, Author) commented Mar 18, 2024

@cjt222 sorry, I haven't been able to finish testing it yet. Will hopefully find more time to work on it soon 😅.

github-actions bot commented Apr 11, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 11, 2024