Skip to content

Commit

Permalink
Return features from Discriminator and use L2 gradient penalty (haven't fixed autograd call yet).
Browse files · Browse the repository at this point in the history
  • Loading branch information
dg845 committed Jan 14, 2024
1 parent acf5175 commit cd82565
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 16 deletions.
10 changes: 6 additions & 4 deletions examples/add/train_add_distill_lora_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import shutil
import types
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import accelerate
import numpy as np
Expand Down Expand Up @@ -573,6 +573,7 @@ class DiscriminatorOutput(BaseOutput):
"""

logits: torch.FloatTensor
features: Optional[Dict[str, torch.FloatTensor]] = None


# Based on ProjectedDiscriminator from the official StyleGAN-T code
Expand Down Expand Up @@ -651,7 +652,7 @@ def forward(
if not return_dict:
return (logits,)

return DiscriminatorOutput(logits=logits)
return DiscriminatorOutput(logits=logits, features=features)


def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
Expand Down Expand Up @@ -1951,7 +1952,8 @@ def compute_image_embeddings(image_batch, image_encoder):
)
head_grad_norm = 0
for grad in head_grad_params:
head_grad_norm += grad.abs().sum()
head_grad_norm += grad.pow(2).sum()
head_grad_norm = head_grad_norm.sqrt()
d_r1_regularizer += head_grad_norm

d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
Expand Down Expand Up @@ -2074,7 +2076,7 @@ def compute_image_embeddings(image_batch, image_encoder):
# Write out additional values for accelerator to report.
logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item()
logs["d_adv_loss_real"] = d_adv_loss_real.detach().item()
logs["d_r1_regularizer"] = d_r1_regularizer.detach().item()
logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item()
logs["d_loss_real"] = d_loss_real.detach().item()
accelerator.log(logs, step=global_step)

Expand Down
10 changes: 6 additions & 4 deletions examples/add/train_add_distill_lora_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import shutil
import types
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import accelerate
import numpy as np
Expand Down Expand Up @@ -590,6 +590,7 @@ class DiscriminatorOutput(BaseOutput):
"""

logits: torch.FloatTensor
features: Optional[Dict[str, torch.FloatTensor]] = None


# Based on ProjectedDiscriminator from the official StyleGAN-T code
Expand Down Expand Up @@ -668,7 +669,7 @@ def forward(
if not return_dict:
return (logits,)

return DiscriminatorOutput(logits=logits)
return DiscriminatorOutput(logits=logits, features=features)


def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
Expand Down Expand Up @@ -2044,7 +2045,8 @@ def compute_image_embeddings(image_batch, image_encoder):
)
head_grad_norm = 0
for grad in head_grad_params:
head_grad_norm += grad.abs().sum()
head_grad_norm += grad.pow(2).sum()
head_grad_norm = head_grad_norm.sqrt()
d_r1_regularizer += head_grad_norm

d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
Expand Down Expand Up @@ -2173,7 +2175,7 @@ def compute_image_embeddings(image_batch, image_encoder):
# Write out additional values for accelerator to report.
logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item()
logs["d_adv_loss_real"] = d_adv_loss_real.detach().item()
logs["d_r1_regularizer"] = d_r1_regularizer.detach().item()
logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item()
logs["d_loss_real"] = d_loss_real.detach().item()
accelerator.log(logs, step=global_step)

Expand Down
10 changes: 6 additions & 4 deletions examples/add/train_add_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import shutil
import types
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import accelerate
import numpy as np
Expand Down Expand Up @@ -555,6 +555,7 @@ class DiscriminatorOutput(BaseOutput):
"""

logits: torch.FloatTensor
features: Optional[Dict[str, torch.FloatTensor]] = None


# Based on ProjectedDiscriminator from the official StyleGAN-T code
Expand Down Expand Up @@ -633,7 +634,7 @@ def forward(
if not return_dict:
return (logits,)

return DiscriminatorOutput(logits=logits)
return DiscriminatorOutput(logits=logits, features=features)


def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
Expand Down Expand Up @@ -1859,7 +1860,8 @@ def compute_image_embeddings(image_batch, image_encoder):
)
head_grad_norm = 0
for grad in head_grad_params:
head_grad_norm += grad.abs().sum()
head_grad_norm += grad.pow(2).sum()
head_grad_norm = head_grad_norm.sqrt()
d_r1_regularizer += head_grad_norm

d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
Expand Down Expand Up @@ -1982,7 +1984,7 @@ def compute_image_embeddings(image_batch, image_encoder):
# Write out additional values for accelerator to report.
logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item()
logs["d_adv_loss_real"] = d_adv_loss_real.detach().item()
logs["d_r1_regularizer"] = d_r1_regularizer.detach().item()
logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item()
logs["d_loss_real"] = d_loss_real.detach().item()
accelerator.log(logs, step=global_step)

Expand Down
10 changes: 6 additions & 4 deletions examples/add/train_add_distill_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import shutil
import types
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import accelerate
import numpy as np
Expand Down Expand Up @@ -572,6 +572,7 @@ class DiscriminatorOutput(BaseOutput):
"""

logits: torch.FloatTensor
features: Optional[Dict[str, torch.FloatTensor]] = None


# Based on ProjectedDiscriminator from the official StyleGAN-T code
Expand Down Expand Up @@ -650,7 +651,7 @@ def forward(
if not return_dict:
return (logits,)

return DiscriminatorOutput(logits=logits)
return DiscriminatorOutput(logits=logits, features=features)


def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
Expand Down Expand Up @@ -1952,7 +1953,8 @@ def compute_image_embeddings(image_batch, image_encoder):
)
head_grad_norm = 0
for grad in head_grad_params:
head_grad_norm += grad.abs().sum()
head_grad_norm += grad.pow(2).sum()
head_grad_norm = head_grad_norm.sqrt()
d_r1_regularizer += head_grad_norm

d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
Expand Down Expand Up @@ -2081,7 +2083,7 @@ def compute_image_embeddings(image_batch, image_encoder):
# Write out additional values for accelerator to report.
logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item()
logs["d_adv_loss_real"] = d_adv_loss_real.detach().item()
logs["d_r1_regularizer"] = d_r1_regularizer.detach().item()
logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item()
logs["d_loss_real"] = d_loss_real.detach().item()
accelerator.log(logs, step=global_step)

Expand Down

0 comments on commit cd82565

Please sign in to comment.