[WIP] Add Adversarial Diffusion Distillation (ADD) Script #6303
base: main
Conversation
def transform(example):
    # resize image
    image = example["image"]
    image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
bilinear causes image artifacting that impacts training quality in major ways. use LANCZOS.
Some evidence would be nice.
We could make this configurable as well if users prefer to use different interpolations.
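A rough sketch of what a configurable interpolation could look like (the `--interpolation_type` argument and its mapping are hypothetical, not something the script currently has; `resolution` and the PIL image come from the surrounding script):

```python
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

# Hypothetical mapping from a CLI string to a torchvision interpolation mode.
INTERPOLATIONS = {
    "nearest": transforms.InterpolationMode.NEAREST,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,  # LANCZOS requires PIL image inputs
}

def transform(example, resolution=512, interpolation_type="lanczos"):
    # Resize with the user-selected interpolation instead of hard-coding BILINEAR.
    image = example["image"]
    image = TF.resize(image, resolution, interpolation=INTERPOLATIONS[interpolation_type])
    return image
```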
if self.prediction_type == "epsilon":
    pred_x_0 = (sample - sigmas * model_output) / alphas
elif self.prediction_type == "sample":
    pred_x_0 = model_output
the u-net returns the residual noise prediction, right? or is it not an intermediary phase with XL Turbo?
The unet returns the residual noise prediction only when prediction_type=epsilon, which is the default for most SD models. This is to support different prediction_types.
well when we train on v_prediction it uses the residual returned from the unet as an input to get_velocity. ergo it is an intermediary stage for v-prediction in Diffusers training. but this code makes it appear as if the sample is returned directly.
I think this is fine as the denoiser does exactly that, it returns the predicted original sample.
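For reference, a hedged sketch of how the usual prediction types map to a predicted x_0 (following the conventions in diffusers schedulers; `alphas`/`sigmas` are the same quantities as in the snippet above):

```python
def predicted_original_sample(model_output, sample, alphas, sigmas, prediction_type):
    # alphas = sqrt(alpha_prod_t), sigmas = sqrt(1 - alpha_prod_t), as in the snippet above.
    if prediction_type == "epsilon":
        # The unet predicts the residual noise; invert the forward process to recover x_0.
        return (sample - sigmas * model_output) / alphas
    elif prediction_type == "sample":
        # The unet directly predicts the original sample.
        return model_output
    elif prediction_type == "v_prediction":
        # The unet predicts v = alpha * eps - sigma * x_0, so x_0 = alpha * x_t - sigma * v.
        return alphas * sample - sigmas * model_output
    raise ValueError(f"Unsupported prediction_type: {prediction_type}")
```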
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
why not use EMAModel class?
Yeah, +1 to that. Let's try using the EMAModel class.
Not fully sure, but using EMA like this (fixed EMA rate) might make sense for distillation, as we might not want to change the model too much. So a fixed, high enough value of rate could be better.
Just to confirm, in EMAModel there's currently no option to set a fixed decay rate? In get_decay it looks like self.decay is not used whether self.use_ema_warmup is True or False:

diffusers/src/diffusers/training_utils.py
Lines 196 to 199 in 1fff527

if self.use_ema_warmup:
    cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
else:
    cur_decay_value = (1 + step) / (10 + step)

In fact, it doesn't seem like self.decay is used in any of the EMA logic at all.
Let me open a PR for that.
@dg845 please correct me if I am wrong here. This is where self.decay is used in get_decay():

diffusers/src/diffusers/training_utils.py
Line 201 in 1fff527

cur_decay_value = min(cur_decay_value, self.decay)

Shouldn't that suffice?
Thanks, I missed that. I think a fixed decay rate can be achieved by setting decay == min_decay.
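A minimal sketch of how that could look, assuming a diffusers version whose EMAModel exposes min_decay (the 0.99 rate is illustrative, and unet/accelerator come from the training script):

```python
from diffusers.training_utils import EMAModel

# Effectively fixed EMA rate: with decay == min_decay and no warmup,
# get_decay() is clamped to the same value at every step.
ema_unet = EMAModel(
    unet.parameters(),
    decay=0.99,
    min_decay=0.99,
    use_ema_warmup=False,
)

# In the training loop, after an optimizer step has actually happened:
# if accelerator.sync_gradients:
#     ema_unet.step(unet.parameters())
```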
action="store_true", | ||
help=( | ||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly" | ||
" cropped. The images will be resized to the resolution first before cropping." |
there's no need to unconditionally resize before crop, especially on a diverse dataset. not resizing first allows better fine details to be learnt.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
# TODO: is there a better way to implement this?
There doesn't seem to be much point to this, since even @PeterL1n et al. showed that zero-terminal SNR doesn't do anything meaningful for epsilon models, and SDXL Turbo doesn't use v-prediction, unfortunately.
I think, as the losses are computed in pixel space, it could still have some effect for epsilon prediction.
if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
!= controlnet: the error message says "Controlnet" but the check is on the unet.
# 9. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
Inference behaves quite differently with bfloat16 vs float16.
student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device)

# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
We're on accelerate > 0.20.0 now.
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
other than matmul you can also enable:
torch.backends.cudnn.allow_tf32 = True
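i.e. both flags could be enabled under the same argument (sketch):

```python
if args.allow_tf32:
    # TF32 matmuls on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    # TF32 for cuDNN convolutions as well
    torch.backends.cudnn.allow_tf32 = True
```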
image, text = batch

image = image.to(accelerator.device, non_blocking=True)
encoded_text = compute_embeddings_fn(text)
precomputing the embeds allows for lower VRAM use.
caption dropout should be implemented too
I think caption dropout is not necessary for this, as here we are distilling the CFG score of the teacher model.
I think it's also fine to NOT precompute the text embeddings for now as we're aiming for a bigger training run here. We can revisit this later.
# encode pixel values with batch size of at most 32
latents = []
for i in range(0, pixel_values.shape[0], 32):
    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
It would be better not to encode latents on the fly, as that substantially increases VRAM use.
An option to re-cache the VAE latents every epoch would be nice, since then the random crop and random flip remain functional.
Absolutely okay to not consider this for now:
https://github.com/huggingface/diffusers/pull/6303/files#r1436066871
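A rough sketch of what pre-caching the VAE latents (e.g. once per epoch) could look like; the dataloader/batch structure here is hypothetical, not the script's actual data pipeline:

```python
import torch

@torch.no_grad()
def cache_vae_latents(vae, dataloader, device, dtype=torch.float16):
    """Encode the whole dataset up front so the VAE does not run on the fly during training."""
    cached_latents = []
    for batch in dataloader:
        pixel_values = batch["pixel_values"].to(device, dtype=dtype)  # hypothetical key
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        cached_latents.append(latents.cpu())
    return torch.cat(cached_latents)
```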
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

# get crop coordinates and crop image
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
What about args.center_crop? Or a crop that preserves aspect via bucketing? Legacy SD training greatly benefits from data bucketing.
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

# get crop coordinates and crop image
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
So this unconditionally crops the example even though the crop coords are only used conditionally, based on the value of use_fix_crop_and_size.
Additionally, it always uses RandomCrop, without args.center_crop being taken into account.
Further, it crops to a square every time; this isn't necessary, you can preserve the aspect ratio.
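A minimal sketch of honoring args.center_crop and avoiding the unconditional resize (hypothetical helper, just to illustrate the suggestion; assumes a PIL image input):

```python
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

def crop_image(image, resolution, center_crop=False):
    # Crop directly to (resolution, resolution) without resizing first, so fine
    # details are preserved; only resize if the image is smaller than the target.
    if min(image.size) < resolution:
        image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.LANCZOS)
    if center_crop:
        c_top = (image.height - resolution) // 2
        c_left = (image.width - resolution) // 2
    else:
        c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
    image = TF.crop(image, c_top, c_left, resolution, resolution)
    return image, (c_top, c_left)
```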
else:
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

validation_prompts = [
you can use a prompt dataset here
For a reference example script, that ain't necessary.
my suggestion was based on the other training example scripts :-)
Great start! The script already covers most of the details for ADD. Left some comments.
Will try to give it a run in next few days.
# 1. Decode real and fake (generated) latents back to pixel space.
# NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the
# pretrained feature network for the discriminator operates in pixel space rather than latent space.
That's correct!
student_gen_image = vae.decode(student_x_0).sample

# 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively.
disc_output_real = discriminator(real_image, prompt_embeds)
What kind of image input does the DINO model expect? Since it's normalized with ImageNet mean and std, we should convert the decoded images to the 0-1 range, like real_image = (real_image / 2 + 0.5).clamp(0, 1).
And I think we should resize the images here, as DINO expects 224x224 images IIRC.
It's done in the FeatureNetwork class.
Did we check the range of expected inputs?
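A sketch of the preprocessing being discussed, assuming DINO-style ImageNet normalization and 224x224 inputs (whether this lives inside the FeatureNetwork class or in the training loop is an implementation choice):

```python
import torch
import torch.nn.functional as F

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def preprocess_for_dino(image):
    # VAE decoder output is roughly in [-1, 1]; map it to [0, 1] first.
    image = (image / 2 + 0.5).clamp(0, 1)
    # DINO-style ViTs typically expect 224x224 inputs.
    image = F.interpolate(image, size=(224, 224), mode="bilinear", align_corners=False)
    # Normalize with ImageNet statistics.
    image = (image - IMAGENET_MEAN.to(image.device, image.dtype)) / IMAGENET_STD.to(image.device, image.dtype)
    return image
```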
lr_scheduler.step()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
we should do ema update here.
optimizer.zero_grad(set_to_none=True)

# 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator
disc_output_fake = discriminator(student_gen_image, prompt_embeds)
This term is already computed above. Do we need to recompute it here? Not sure, because we are not doing vanilla GAN training here, so we might as well be able to reuse that.
validation_prompts = [
    "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
    "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
Prompts from the ADD paper could also be used.
Changed prompts to those used in the ADD paper (note that all example images from the paper are generated by ADD-XL).
# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
# TODO: is there a better way to implement this?
Your current implementation looks good to me!
Note that I'm currently using DDIMScheduler, which currently supports rescale_betas_zero_snr, but ideally I'd like to use DDPMScheduler here, since my understanding is that DDPMScheduler can typically load DDIMScheduler configs but not vice versa. Since DDPMScheduler currently does not support rescale_betas_zero_snr, I've opened a PR to add it: #6305.
you can simply use Euler now, instead of DDIM, because it also supports zero-terminal SNR. this would match the behaviour of the ControlNet trainer, which uses Euler.
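For example, loading the scheduler with zero terminal SNR enforced could look like this (a sketch, assuming a diffusers version where EulerDiscreteScheduler exposes rescale_betas_zero_snr):

```python
from diffusers import EulerDiscreteScheduler

# Load the teacher's scheduler config but enforce zero terminal SNR
# (section 3.1 of the ADD paper).
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
    args.pretrained_teacher_model,
    subfolder="scheduler",
    rescale_betas_zero_snr=True,
)
```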
pixel_values = image.to(dtype=weight_dtype)
if vae.dtype != weight_dtype:
    vae.to(dtype=weight_dtype)
As hinted by @patil-suraj above, we can safely always have the VAE in a reduced precision in case of SD.
Not for SD 2.1, which ends up with NaNs inside the VAE with half precision on finetuned models. Never figured that one out.
Oh, for SD2.1, that is the case?
Yes, it's noticeably an issue when finetuning a 2.1-v model in Diffusers and then trying to do inference later on an AMD system, which lacks xformers etc.: the float32 VAE OOMs but float16 produces NaNs.
Amazing work. Excited to see how the results come up!
I'd be keen on experimenting with simpler discriminators though. But that is obviously not a blocker.
Excuse my ignorance, but how do I run it? I seem to get this error trying to run either the LoRA or SD 1.5 version:
Traceback (most recent call last):
I think there are a few additional arguments that need to be explicitly supplied for the scripts to not raise an error. Something close to the minimal set of arguments needed is

accelerate launch examples/add/train_add_distill_lora_sd_wds.py \
    --pretrained_teacher_model="<teacher_model>" \
    --train_shards_path_or_url="<dataset>" \
    --output_dir="<output_dir>" \
    --max_train_steps=1 \
    --max_train_samples=20 \
    --dataloader_num_workers=8 \

assuming the other default values work (for example, …). Note that the scripts are a work in progress and there's no guarantee that they work currently.
Got it running, but ran into a bug when saving. The validation images also looked like random noise.
d_r1_regularizer = 0
for k, head in discriminator.heads.items():
    head_grad_params = torch.autograd.grad(
        outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True
According to the paper, the R1 penalty seems to be computed w.r.t. the head's input instead of the head's parameters?
Thanks for the catch, I think you're right. In fact there seem to be several problems in the current implementation:

- As you pointed out, the gradient penalty should be calculated with respect to the discriminator head inputs. That's not available in the current code, but if the input to discriminator head k is available as features[k], I think the fix would be to set the inputs argument to torch.autograd.grad to features[k] (if I understand autograd correctly).
- It looks like I misunderstood the definition of the R1 gradient penalty. The ADD paper cites this paper when discussing the R1 gradient penalty, and the latter paper defines the R1 gradient penalty as

  $$R_1(\psi) = \frac{\gamma}{2}\,\mathbb{E}_{p_{\mathcal{D}}(x)}\left[\lVert \nabla_x D_\psi(x) \rVert^2\right]$$

So it seems like we should be using the L2 norm rather than the L1 norm when calculating the gradient penalty. It's also possible that the implementation is off by a constant factor.

@patil-suraj @sayakpaul does this sound correct to you guys?
Yeah, regarding the formula it is the L2 norm, something like this:

d_r1_regularizer = sum((torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1) ** 2).mean() for grad in feature_grads)
I have added a tentative fix for the discriminator R1 gradient penalty for the SD ADD script in commit ab46142. In particular this part

diffusers/examples/add/train_add_distill_sd_wds.py
Lines 1860 to 1862 in ab46142

for k, feature in features_real.items():
    # Required so that the torch.autograd.grad call below works properly?
    feature.requires_grad_(True)

feels weird to me but seems necessary because the features in features_real don't usually have gradients because the feature_network is frozen. @erliding @patil-suraj @sayakpaul would be great if you could look this over.
Yep, you need to explicitly call feature.requires_grad_() before passing them to the heads.
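Putting the pieces of this thread together, a hedged sketch of the R1 penalty computed w.r.t. the head inputs (names like features_real and discriminator.heads follow the snippets above, the head call signature is assumed, and the γ/2 scaling is left to a separate weight):

```python
import torch

d_r1_regularizer = 0.0
for k, feature in features_real.items():
    # The frozen feature network produces tensors with no grad history,
    # so gradients w.r.t. the head inputs must be enabled explicitly.
    feature.requires_grad_(True)
    head_out = discriminator.heads[k](feature)  # assumed call signature for a head
    # Standard R1: squared L2 norm of d(head output)/d(head input) on real data.
    (feature_grad,) = torch.autograd.grad(
        outputs=head_out.sum(), inputs=feature, create_graph=True
    )
    d_r1_regularizer = d_r1_regularizer + (
        torch.linalg.vector_norm(feature_grad.reshape(feature_grad.size(0), -1), dim=1) ** 2
    ).mean()
```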
@sayakpaul As far as I understand, the gradient penalty could help enforce Lipschitz continuity (i.e., bound the gradient of the output w.r.t. the input) on the discriminator, which is a requirement of Wasserstein GANs.
Right, but we aren't using the earth mover's distance in the loss here, no?
yeah, it's hinge loss here :) gradient penalty should bring similar benefits though
@dg845 It does look good to me. We need to enable grads for the feature inputs to be able to compute the grad penalty; if the input does not have grad enabled, I think create_graph=True will complain.
For reference, DDGAN training has this.
…'t fixed autograd call yet).
if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
    discriminator_optimizer.step()
If I understand correctly, when gradient_accumulation_steps > 1 the current implementation seems to result in simultaneous gradient descent up to the last accumulation step of each batch, while when gradient_accumulation_steps == 1 it is alternating gradient descent. In StyleGAN-T it is always alternating gradient descent.
How important is that to follow? Is it absolutely a must for training stability?
The convergence behavior of simultaneous gradient descent and alternating gradient descent differs when approaching the min-max equilibrium. In recent GAN implementations, alternating gradient descent seems to be the usual choice; not sure how much it helps in the case of ADD, though.
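To make the distinction concrete, a toy sketch of alternating updates (discriminator step, then generator step, each with its own backward and optimizer step); all names here are illustrative, not the script's actual loss functions:

```python
# Alternating gradient descent, per batch:
# 1) update D with G fixed, 2) update G against the freshly updated D.
for batch in dataloader:
    # --- discriminator step ---
    d_optimizer.zero_grad(set_to_none=True)
    d_loss = discriminator_loss(discriminator, generator, batch)  # hypothetical helper
    d_loss.backward()
    d_optimizer.step()

    # --- generator step (uses the just-updated discriminator) ---
    g_optimizer.zero_grad(set_to_none=True)
    g_loss = generator_loss(discriminator, generator, batch)  # hypothetical helper
    g_loss.backward()
    g_optimizer.step()
```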
…ator R1 gradient penalty.
if args.use_image_conditioning:
    image_embedding = encoded_image.pop("image_embeds").float()
    # Only supply image conditioning when student timestep is not last training timestep T.
    image_embedding = torch.where(
Currently the same masked image_embedding is fed to the discriminator for both real and fake images, but it seems that whether a real image's image_embedding is masked out could instead be random with a rate of, say, 1 / num_inference_steps rather than depending on student_timesteps. Not sure if this would make a big difference, though.
Yeah, same. I am not sure either how impactful this would be.
student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long()
student_timesteps = student_timestep_schedule[student_index]
teacher_timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
It might be better not to sample teacher_timesteps from the full range of [0, noise_scheduler.config.num_train_timesteps), but instead to ignore timesteps that are too small or too big, e.g. the default configurable range from DreamFusion is [0.02, 0.98] * num_train_timesteps.
I guess Figure 2 of the ADD paper implies that they sample from the full range of teacher timesteps. But this is definitely something we can try out :).
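If restricting the teacher timestep range is tried, a sketch along the lines of the DreamFusion default could be (the 0.02/0.98 fractions are the values quoted above; noise_scheduler, bsz, and latents come from the quoted snippet):

```python
# Sample teacher timesteps away from the extremes of the schedule,
# e.g. from [0.02, 0.98] * num_train_timesteps as in DreamFusion.
t_min = int(0.02 * noise_scheduler.config.num_train_timesteps)
t_max = int(0.98 * noise_scheduler.config.num_train_timesteps)
teacher_timesteps = torch.randint(
    t_min, t_max, (bsz,), device=latents.device
).long()
```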
How far away is this PR from being merged?
Hi @SteamedGit, the ADD implementation is nominally complete, but I have not been able to test whether the script can distill good models (e.g. for SD v1.5) yet.
…itive value instead of zero following EulerDiscreteScheduler.
…y whether we use a CLIPTextModel or CLIPTextModelWithProjection (e.g. with --use_pretrained_projection).
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Not stale.
@sayakpaul @dg845 Great job! Can someone please confirm whether the effectiveness of this PR has been verified?
Regarding computing the SDS loss, I suggest taking a look at https://arxiv.org/abs/2306.04619, which tends to produce a better target.
@cjt222 sorry, I haven't been able to finish testing it yet. Will hopefully find more time to work on it soon 😅.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
This PR adds an example script for adversarial diffusion distillation (ADD) (paper, code), a distillation + adversarial training method used to distill SD/SD-XL Turbo.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten
@sayakpaul
@patil-suraj