Skip to content

Commit

Permalink
run formatting scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Flynn committed Oct 19, 2024
1 parent 181c347 commit ed808fa
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 200 deletions.
17 changes: 11 additions & 6 deletions src/transformers/models/encodec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
_import_structure = {
"configuration_encodec": ["EncodecConfig"],
"feature_extraction_encodec": ["EncodecFeatureExtractor"],
"loss_encodec": ["compute_discriminator_loss", "compute_generator_adv_loss", "compute_feature_matching_loss", "Balancer"],
"loss_encodec": [
"compute_discriminator_loss",
"compute_generator_adv_loss",
"compute_feature_matching_loss",
"Balancer",
],
}

try:
Expand All @@ -45,10 +50,10 @@
)
from .feature_extraction_encodec import EncodecFeatureExtractor
from .loss_encodec import (
Balancer,
compute_discriminator_loss,
compute_generator_adv_loss,
compute_feature_matching_loss,
Balancer
compute_generator_adv_loss,
)

try:
Expand All @@ -58,13 +63,13 @@
pass
else:
from .modeling_encodec import (
EncodecDiscriminator,
EncodecDiscriminatorConfig,
EncodecModel,
EncodecPreTrainedModel,
EncodecDiscriminatorConfig,
EncodecDiscriminator,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
56 changes: 34 additions & 22 deletions src/transformers/models/encodec/loss_encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import torch.nn.functional as F


'''
"""
Balancer code directly copied from: https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py
'''
from collections import defaultdict
"""
import typing as tp
from collections import defaultdict
from typing import List

from torch import autograd
from torch.nn.utils import spectral_norm, weight_norm
import einops


class Balancer:
"""Loss balancer.
Expand Down Expand Up @@ -48,9 +48,16 @@ class Balancer:
monitor (bool): Whether to store additional ratio for each loss key in metrics.
"""

def __init__(self, weights: tp.Dict[str, float], rescale_grads: bool = True, total_norm: float = 1.,
ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
monitor: bool = False):
def __init__(
self,
weights: tp.Dict[str, float],
rescale_grads: bool = True,
total_norm: float = 1.0,
ema_decay: float = 0.999,
per_batch_item: bool = True,
epsilon: float = 1e-12,
monitor: bool = False,
):
self.weights = weights
self.per_batch_item = per_batch_item
self.total_norm = total_norm
Expand All @@ -68,7 +75,7 @@ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
norms = {}
grads = {}
for name, loss in losses.items():
grad, = autograd.grad(loss, [input], retain_graph=True, allow_unused=True)
(grad,) = autograd.grad(loss, [input], retain_graph=True, allow_unused=True)
if grad is not None:
if self.per_batch_item:
dims = tuple(range(1, grad.dim()))
Expand All @@ -87,7 +94,7 @@ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
self._metrics = {}
if self.monitor:
for k, v in avg_norms.items():
self._metrics[f'ratio_{k}'] = v / total
self._metrics[f"ratio_{k}"] = v / total

total_weights = sum([self.weights[k] for k in avg_norms])
ratios = {k: w / total_weights for k, w in self.weights.items()}
Expand All @@ -113,23 +120,27 @@ def world_size():
def is_distributed():
    """Return True when training spans more than one worker."""
    num_workers = world_size()
    return num_workers > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """Reduce `tensor` across all workers with `op`; no-op outside distributed runs."""
    if not is_distributed():
        return
    return torch.distributed.all_reduce(tensor, op)
def average_metrics(metrics: tp.Dict[str, float], count=1.):


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers.

    Args:
        metrics: Mapping of metric name to this worker's local value.
        count: Optional unnormalized weight for this worker's contribution.

    Returns:
        `metrics` unchanged in single-process runs (or when empty); otherwise a
        new dict with each value averaged across workers, weighted by `count`.
    """
    # Fix: a stale duplicate of the `device = ...` assignment (left over from a
    # quote-style reformat) was removed; also guard the empty-dict case, which
    # previously crashed on unpacking `zip(*metrics.items())`.
    if not is_distributed() or not metrics:
        return metrics
    # NOTE(review): assumes every worker builds `metrics` with the same keys in
    # the same insertion order — verify against callers.
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Append a trailing 1 so the summed weights travel with the values in a
    # single all-reduce; dividing by it normalizes the weighted sum.
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))


def averager(beta: float = 1):
"""
Exponential Moving Average callback.
Expand All @@ -148,12 +159,12 @@ def _update(metrics: tp.Dict[str, tp.Any], weight: float = 1) -> tp.Dict[str, fl
total[key] = total[key] * beta + weight * float(value)
fix[key] = fix[key] * beta + weight
return {key: tot / fix[key] for key, tot in total.items()}

return _update


def compute_discriminator_loss(
real_logits: List[torch.Tensor],
fake_logits: List[torch.Tensor],
num_discriminators: int
real_logits: List[torch.Tensor], fake_logits: List[torch.Tensor], num_discriminators: int
) -> torch.Tensor:
"""
Compute the discriminator loss based on real and fake logits.
Expand All @@ -171,10 +182,8 @@ def compute_discriminator_loss(
loss += torch.mean(F.relu(1 - real_logit)) + torch.mean(F.relu(1 + fake_logit))
return loss / num_discriminators

def compute_generator_adv_loss(
fake_logits: List[torch.Tensor],
num_discriminators: int
) -> torch.Tensor:

def compute_generator_adv_loss(fake_logits: List[torch.Tensor], num_discriminators: int) -> torch.Tensor:
"""
Compute the generator adversarial loss using fake logits.
Expand All @@ -190,7 +199,10 @@ def compute_generator_adv_loss(
loss += torch.mean(F.relu(1 - fake_logit))
return loss / num_discriminators

def compute_feature_matching_loss(real_features: List[List[torch.Tensor]], fake_features: List[List[torch.Tensor]], num_discriminators: int):

def compute_feature_matching_loss(
real_features: List[List[torch.Tensor]], fake_features: List[List[torch.Tensor]], num_discriminators: int
):
"""
Compute the feature matching loss between real and fake features.
Expand All @@ -206,5 +218,5 @@ def compute_feature_matching_loss(real_features: List[List[torch.Tensor]], fake_
for k in range(num_discriminators):
for real_feat, fake_feat in zip(real_features[k], fake_features[k]):
fm_loss += F.l1_loss(fake_feat, real_feat.detach()) / torch.mean(torch.abs(real_feat.detach()))
fm_loss /= (num_discriminators * len(real_features[0]))
return fm_loss
fm_loss /= num_discriminators * len(real_features[0])
return fm_loss
Loading

0 comments on commit ed808fa

Please sign in to comment.