Skip to content

Commit

Permalink
[not sure if correct] Add tentative fixed implementation of discriminator R1 gradient penalty.
Browse files Browse the repository at this point in the history
  • Loading branch information
dg845 committed Jan 15, 2024
1 parent cd82565 commit ab46142
Showing 1 changed file with 46 additions and 31 deletions.
77 changes: 46 additions & 31 deletions examples/add/train_add_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,33 +609,46 @@ def train(self, mode: bool = True):
def eval(self):
    """Switch to evaluation mode by delegating to ``self.train`` with ``mode=False``.

    Returns:
        Whatever ``self.train`` returns.
    """
    return self.train(mode=False)

def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]:
    """Run the feature network on a batch of images.

    Args:
        image: Input image tensor.
        transform_positive: When true, ``image`` is mapped with ``(image + 1) / 2``
            before the feature network is applied (presumably the caller passes
            values in [-1, 1] -- confirm against call sites).

    Returns:
        The output of ``self.feature_network`` for the (possibly shifted) image.
        NOTE(review): the heads index this result by string key, so it is
        expected to be a mapping from feature name to tensor.
    """
    # TODO: do we need the augmentations from the original StyleGAN-T code?
    if transform_positive:
        # Transform to [0, 1].
        # ``add``/``div`` are out-of-place, so the caller's tensor is untouched.
        image = image.add(1).div(2)

    # Forward pass through feature network.
    return self.feature_network(image)

def forward_features(
    self,
    features: Dict[str, torch.Tensor],
    c_text: torch.Tensor,
    c_img: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """Apply every discriminator head to pre-computed features.

    Args:
        features: Mapping from feature name to feature tensor; it must contain
            an entry for every key in ``self.heads`` and the key ``"0"``.
        c_text: Text conditioning passed unchanged to every head.
        c_img: Optional image conditioning passed unchanged to every head.
        return_dict: When false, return a one-element tuple ``(logits,)``
            instead of a ``DiscriminatorOutput``.

    Returns:
        ``DiscriminatorOutput(logits=..., features=...)`` or ``(logits,)``, where
        ``logits`` is the per-head outputs flattened to ``(batch, -1)`` and
        concatenated along dim 1.
    """
    # NOTE(review): the batch size is read from the feature stored under "0";
    # this assumes every feature tensor shares that leading dimension.
    batch_size = features["0"].size(0)

    # Apply discriminator heads.
    logits = [
        head(features[k], c_text, c_img).view(batch_size, -1)
        for k, head in self.heads.items()
    ]
    logits = torch.cat(logits, dim=1)

    if not return_dict:
        return (logits,)

    return DiscriminatorOutput(logits=logits, features=features)

def forward(
    self,
    image: torch.Tensor,
    c_text: torch.Tensor,
    c_img: Optional[torch.Tensor] = None,
    transform_positive: bool = True,
    return_dict: bool = True,
):
    """Run the full discriminator pass on a batch of images.

    This chains ``get_features`` (feature extraction, optionally preceded by
    the ``transform_positive`` shift) with ``forward_features`` (the heads).

    Args:
        image: Input image tensor handed to ``get_features``.
        c_text: Text conditioning forwarded to ``forward_features``.
        c_img: Optional image conditioning forwarded to ``forward_features``.
        transform_positive: Forwarded to ``get_features``.
        return_dict: Forwarded to ``forward_features``.

    Returns:
        The result of ``forward_features``, unchanged.
    """
    extracted = self.get_features(image, transform_positive=transform_positive)
    return self.forward_features(extracted, c_text, c_img=c_img, return_dict=return_dict)


def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
logger.info("Running validation... ")
Expand Down Expand Up @@ -1777,7 +1790,7 @@ def compute_image_embeddings(image_batch, image_encoder):
text_embedding = encoded_text.pop("text_embedding")
image_embedding = None
if args.use_image_conditioning:
image_embedding = encoded_image.pop("image_embeds")
image_embedding = encoded_image.pop("image_embeds").float()
# Only supply image conditioning when student timestep is not last training timestep T.
image_embedding = torch.where(
student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1,
Expand Down Expand Up @@ -1842,32 +1855,34 @@ def compute_image_embeddings(image_batch, image_encoder):
)
student_gen_image = torch.cat(student_gen_image, dim=0)

# 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively.
disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding)
disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding)

# 3. Calculate the discriminator real adversarial loss terms.
d_logits_real = disc_output_real.logits
# 2. Calculate the discriminator real adversarial loss terms.
features_real = discriminator.get_features(pixel_values.float())
for k, feature in features_real.items():
# Required so that the torch.autograd.grad call below works properly?
feature.requires_grad_(True)
d_logits_real = discriminator.forward_features(
features_real, text_embedding.float(), image_embedding, return_dict=False
)[0]
# Use hinge loss (see section 3.2, Equation 3 of paper)
d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real))

# 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real
# data.
# 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the
# discriminator head input features from the real data.
d_r1_regularizer = 0
for k, head in discriminator.heads.items():
head_grad_params = torch.autograd.grad(
outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True
)
head_grad_norm = 0
for grad in head_grad_params:
head_grad_norm += grad.pow(2).sum()
head_grad_norm = head_grad_norm.sqrt()
d_r1_regularizer += head_grad_norm
grad_params = torch.autograd.grad(
outputs=d_adv_loss_real,
inputs=features_real.values(),
create_graph=True,
)
for grad in grad_params:
d_r1_regularizer += grad.pow(2).sum()
d_r1_regularizer = d_r1_regularizer.sqrt()

d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
accelerator.backward(d_loss_real, retain_graph=True)

# 5. Calculate the discriminator fake adversarial loss terms.
# 4. Calculate the discriminator fake adversarial loss terms.
disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding)
d_logits_fake = disc_output_fake.logits
# Use hinge loss (see section 3.2, Equation 3 of paper)
d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake))
Expand Down

0 comments on commit ab46142

Please sign in to comment.