use new loss API
jonflynng committed Oct 29, 2024
1 parent 1ff4a6a commit d50b0f0
Showing 6 changed files with 97 additions and 75 deletions.
56 changes: 56 additions & 0 deletions src/transformers/loss/loss_encodec.py
@@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F


def EncodecLoss(model, input_values, audio_values):
    """
    Computes the reconstruction and commitment losses for the Encodec model.

    Args:
        model: The `EncodecModel` instance.
        input_values (`torch.Tensor`): Original input audio.
        audio_values (`torch.Tensor`): Reconstructed audio from the model.

    Returns:
        `tuple`: A tuple containing `(reconstruction_loss, commitment_loss)`.
    """
    # Commitment loss: MSE between encoder embeddings and their quantized versions
    embeddings = model.encoder(input_values)
    _, quantization_steps = model.quantizer.encode(embeddings, bandwidth=None)

    commitment_loss = torch.tensor(0.0, device=input_values.device)
    for residual, quantize in quantization_steps:
        loss = F.mse_loss(quantize.permute(0, 2, 1), residual.permute(0, 2, 1))
        commitment_loss += loss
    commitment_loss *= model.commitment_weight

    # Reconstruction loss
    # Time-domain term: L1 distance between the waveforms
    time_loss = F.l1_loss(audio_values, input_values)

    # Frequency-domain term: L1 + L2 over mel spectrograms at several STFT scales
    scales = [2**i for i in range(5, 12)]
    frequency_loss = 0.0
    for scale in scales:
        n_fft = scale
        hop_length = scale // 4
        S_x = model.compute_mel_spectrogram(input_values, n_fft, hop_length, n_mels=64)
        S_x_hat = model.compute_mel_spectrogram(audio_values, n_fft, hop_length, n_mels=64)
        l1 = F.l1_loss(S_x_hat, S_x)
        l2 = F.mse_loss(S_x_hat, S_x)
        frequency_loss += l1 + l2

    frequency_loss = frequency_loss / (len(scales) * 2)

    # Combine losses
    lambda_t = 1.0  # relative weight of the time-domain term
    lambda_f = 1.0  # relative weight of the frequency-domain term
    reconstruction_loss = lambda_t * time_loss + lambda_f * frequency_loss

    return reconstruction_loss, commitment_loss
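
This function is keyed by the model's `loss_type` (see `loss_utils.py` below) and called from `EncodecModel.forward` during training. As a rough sketch, it can also be invoked directly — this assumes the commit's model, which exposes `encoder`, `quantizer`, `commitment_weight`, and `compute_mel_spectrogram` as used above; the checkpoint name and shapes are illustrative:

```python
# Hedged sketch: calling EncodecLoss by hand on a round-tripped waveform.
import torch
from transformers import EncodecModel
from transformers.loss.loss_encodec import EncodecLoss

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
input_values = torch.randn(1, 1, 24000)  # (batch, channels, samples)

audio_codes, audio_scales = model.encode(input_values, return_dict=False)
audio_values = model.decode(audio_codes, audio_scales, return_dict=False)[0]
audio_values = audio_values[..., : input_values.shape[-1]]  # trim decoder padding

reconstruction_loss, commitment_loss = EncodecLoss(model, input_values, audio_values)
```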
2 changes: 2 additions & 0 deletions src/transformers/loss/loss_utils.py
@@ -19,6 +19,7 @@
from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss
+from .loss_encodec import EncodecLoss


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
@@ -111,4 +112,5 @@ def ForTokenClassification(logits, labels, config, **kwargs):
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"Encodec": EncodecLoss,
}
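
For orientation, `LOSS_MAPPING` is consumed via the `loss_type` set on the model's config, falling back to the model class name when none is set. A simplified sketch of that resolution — an approximation of the library's `PreTrainedModel.loss_function` lookup, not a verbatim copy:

```python
# Simplified sketch of LOSS_MAPPING resolution (an approximation, not the exact
# library code): config.loss_type wins, otherwise fall back to the class name.
from transformers.loss.loss_utils import LOSS_MAPPING

def resolve_loss_function(config, model_class_name: str):
    loss_type = getattr(config, "loss_type", None) or model_class_name
    return LOSS_MAPPING.get(loss_type)
```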
1 change: 1 addition & 0 deletions src/transformers/models/encodec/configuration_encodec.py
@@ -104,6 +104,7 @@ class EncodecConfig(PretrainedConfig):
```"""

    model_type = "encodec"
+    loss_type = "Encodec"

    def __init__(
        self,
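Since `loss_type` is a class attribute, every `EncodecConfig` instance now advertises which `LOSS_MAPPING` entry to use; a quick illustrative check, assuming this commit:

```python
# Illustrative check: the new class attribute keys into LOSS_MAPPING above.
from transformers import EncodecConfig

config = EncodecConfig()
assert config.loss_type == "Encodec"
```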
25 changes: 11 additions & 14 deletions src/transformers/models/encodec/loss_encodec.py
@@ -6,12 +6,9 @@
import torch.nn.functional as F
from torch import autograd

"""
Balancer code directly copied from: https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py
"""

class Balancer:
"""Loss balancer.
@@ -49,14 +46,14 @@ class Balancer:
    """

    def __init__(
-            self,
-            weights: tp.Dict[str, float],
-            rescale_grads: bool = True,
-            total_norm: float = 1.0,
-            ema_decay: float = 0.999,
-            per_batch_item: bool = True,
-            epsilon: float = 1e-12,
-            monitor: bool = False,
+        self,
+        weights: tp.Dict[str, float],
+        rescale_grads: bool = True,
+        total_norm: float = 1.0,
+        ema_decay: float = 0.999,
+        per_batch_item: bool = True,
+        epsilon: float = 1e-12,
+        monitor: bool = False,
    ):
        self.weights = weights
        self.per_batch_item = per_batch_item
@@ -164,7 +161,7 @@ def _update(metrics: tp.Dict[str, tp.Any], weight: float = 1) -> tp.Dict[str, fl


def compute_discriminator_loss(
-        real_logits: List[torch.Tensor], fake_logits: List[torch.Tensor], num_discriminators: int
+    real_logits: List[torch.Tensor], fake_logits: List[torch.Tensor], num_discriminators: int
) -> torch.Tensor:
"""
Compute the discriminator loss based on real and fake logits.
Expand Down Expand Up @@ -201,7 +198,7 @@ def compute_generator_adv_loss(fake_logits: List[torch.Tensor], num_discriminato


def compute_feature_matching_loss(
-        real_features: List[List[torch.Tensor]], fake_features: List[List[torch.Tensor]], num_discriminators: int
+    real_features: List[List[torch.Tensor]], fake_features: List[List[torch.Tensor]], num_discriminators: int
):
"""
Compute the feature matching loss between real and fake features.
Expand All @@ -219,4 +216,4 @@ def compute_feature_matching_loss(
        for real_feat, fake_feat in zip(real_features[k], fake_features[k]):
            fm_loss += F.l1_loss(fake_feat, real_feat.detach()) / torch.mean(torch.abs(real_feat.detach()))
    fm_loss /= num_discriminators * len(real_features[0])
-    return fm_loss
\ No newline at end of file
+    return fm_loss
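
A minimal self-contained sketch of the `Balancer` with two toy losses — the weights here are illustrative, not values taken from the paper or this repository:

```python
# Toy Balancer usage: each loss's gradient w.r.t. the model output is rescaled
# so its share of total_norm matches its weight, then backpropagated.
import torch

model_output = torch.randn(2, 1, 100, requires_grad=True)
losses = {
    "l1": model_output.abs().mean(),
    "l2": (model_output ** 2).mean(),
}

balancer = Balancer(weights={"l1": 1.0, "l2": 3.0})
balancer.backward(losses=losses, input=model_output)
print(model_output.grad.norm())  # combined, rescaled gradient
```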
67 changes: 16 additions & 51 deletions src/transformers/models/encodec/modeling_encodec.py
@@ -864,65 +864,30 @@ def forward(
        if audio_codes is not None and audio_scales is None:
            raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")

-        reconstruction_loss = None
-        commitment_loss = None
-
-        if audio_scales is None and audio_codes is None:
+        if audio_codes is None:
            audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, return_dict=False)

-        if return_loss:
-            embeddings = self.encoder(input_values)
-            _, quantization_steps = self.quantizer.encode(embeddings, bandwidth)
-
-            commitment_loss = torch.tensor(0.0, device=input_values.device)
-            for residual, quantize in quantization_steps:
-                loss = F.mse_loss(quantize.permute(0, 2, 1), residual.permute(0, 2, 1))
-                commitment_loss += loss
-            commitment_loss *= self.commitment_weight
-
        decoded_output = self.decode(audio_codes, audio_scales)
        audio_values = decoded_output.audio_values[:, :, : input_values.shape[-1]]

-        if return_loss:
-            # Time domain loss
-            time_loss = F.l1_loss(audio_values, input_values)
-            print(f"Time loss: {time_loss.item()}")
-
-            # Frequency domain loss
-            scales = [2**i for i in range(5, 12)]
-            frequency_loss = 0.0
-            for scale in scales:
-                n_fft = scale
-                hop_length = scale // 4
-                S_x = self.compute_mel_spectrogram(input_values, n_fft, hop_length, n_mels=64)
-                S_x_hat = self.compute_mel_spectrogram(audio_values, n_fft, hop_length, n_mels=64)
-                l1 = F.l1_loss(S_x_hat, S_x)
-                l2 = F.mse_loss(S_x_hat, S_x)
-                frequency_loss += l1 + l2
-
-            frequency_loss = frequency_loss / (len(scales) * 2)
-            print(f"Average frequency loss: {frequency_loss.item()}")
-
-            # Combine losses
-            lambda_t = 1.0  # look at this further, not sure why the need for a weight here
-            lambda_f = 1.0
-            reconstruction_loss = lambda_t * time_loss + lambda_f * frequency_loss
-            print(f"Reconstruction loss: {reconstruction_loss.item()}")
-
-        if commitment_loss is not None:
-            print(f"Commitment loss: {commitment_loss.item()}")
-
-        audio_values_to_return = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0]
+        outputs = {
+            "audio_codes": audio_codes,
+            "audio_values": audio_values,
+        }
+
+        if self.training and self.loss_function is not None:
+            reconstruction_loss, commitment_loss = self.loss_function(
+                model=self,
+                input_values=input_values,
+                audio_values=audio_values,
+            )
+            outputs["reconstruction_loss"] = reconstruction_loss
+            outputs["commitment_loss"] = commitment_loss

        if not return_dict:
-            return (audio_codes, audio_values_to_return, reconstruction_loss, commitment_loss)
+            return tuple(outputs.values())

-        return EncodecOutput(
-            audio_codes=audio_codes,
-            audio_values=audio_values_to_return,
-            reconstruction_loss=reconstruction_loss,
-            commitment_loss=commitment_loss,
-        )
+        return EncodecOutput(**outputs)


"""
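The net effect on the public API: `return_loss` disappears from the signature, and the two losses are attached to the output only in training mode. A sketch of the resulting call patterns, assuming this commit (checkpoint name illustrative):

```python
# Sketch of the forward contract after this change.
import torch
from transformers import EncodecModel

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
audio = torch.randn(1, 1, 24000)

model.eval()
with torch.no_grad():
    out = model(audio, return_dict=True)  # audio_codes and audio_values only

model.train()
out = model(audio, return_dict=True)      # adds reconstruction_loss, commitment_loss
(out.reconstruction_loss + out.commitment_loss).backward()
```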
21 changes: 11 additions & 10 deletions tests/models/encodec/test_modeling_encodec.py
@@ -240,7 +240,8 @@ def test_training_with_discriminator(self):
            discriminator_optimizer.zero_grad()

            # Generate fake audio with the generator
-            outputs = model(input_values, return_dict=True, return_loss=True)
+            with torch.no_grad():
+                outputs = model(input_values, return_dict=True)
            fake_audio = outputs.audio_values.detach()  # Detach to prevent gradients flowing to the generator

            real_logits, _ = discriminator(real_audio)
@@ -258,9 +259,9 @@
            # Train Generator
            generator_optimizer.zero_grad()

-            # Generate fake audio again (this time gradients flow back to the generator)
-            outputs = model(input_values, return_dict=True, return_loss=True)
-            fake_audio = outputs.audio_values  # Do not detach
+            # Generate fake audio and compute losses
+            outputs = model(input_values, return_dict=True)
+            fake_audio = outputs.audio_values

            # Compute generator adversarial loss and feature matching loss
            fake_logits, fake_features = discriminator(fake_audio)
@@ -271,14 +272,13 @@
                fake_logits=fake_logits, num_discriminators=discriminator.num_discriminators
            )

-            # Feature matching loss (Equation 2 in paper)
+            # Feature matching loss
            fm_loss = compute_feature_matching_loss(
                real_features=real_features,
                fake_features=fake_features,
                num_discriminators=discriminator.num_discriminators,
            )

-            # Combine losses using the Balancer
            losses_to_balance = {
                "reconstruction_loss": outputs.reconstruction_loss,
                "g_adv_loss": g_adv_loss,
@@ -288,9 +288,8 @@
            # Model output (the reconstructed audio)
            model_output = outputs.audio_values

-            balancer.backward(losses_to_balance, model_output)
+            balancer.backward(losses=losses_to_balance, input=model_output)

-            # Add commitment loss separately as per paper
            if outputs.commitment_loss is not None:
                outputs.commitment_loss.backward()
@@ -303,10 +302,12 @@
                print("Discriminator not updated this epoch")
            print(f"Generator adversarial loss: {g_adv_loss.item():.4f}")
            print(f"Feature matching loss: {fm_loss.item():.4f}")
-            print(f"Reconstruction loss (no commit): {outputs.reconstruction_loss.item():.4f}")
+            print(f"Reconstruction loss: {outputs.reconstruction_loss.item():.4f}")
            if outputs.commitment_loss is not None:
                print(f"Commitment loss: {outputs.commitment_loss.item():.4f}")
-            total_gen_loss = outputs.reconstruction_loss.item() + g_adv_loss.item() + fm_loss.item()
+            total_gen_loss = (
+                outputs.reconstruction_loss.item() + g_adv_loss.item() + fm_loss.item()
+            )
            if outputs.commitment_loss is not None:
                total_gen_loss += outputs.commitment_loss.item()
            print(f"Total generator loss (before balancing): {total_gen_loss:.4f}\n")
