Skip to content

Commit

Permalink
improve MTL
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Jan 6, 2025
1 parent 52ae7d7 commit 877b37d
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 20 deletions.
18 changes: 18 additions & 0 deletions multimolecule/module/criterions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
# <https://multimolecule.danling.org/about/license-faq>.


from .balancer import (
DynamicWeightAverageBalancer,
EqualWeightBalancer,
GeometricLossBalancer,
GradNormBalancer,
LossBalancer,
LossBalancerRegistry,
RandomLossWeightBalancer,
UncertaintyWeightBalancer,
)
from .binary import BCEWithLogitsLoss
from .generic import Criterion
from .multiclass import CrossEntropyLoss
Expand All @@ -29,9 +39,17 @@

__all__ = [
"CriterionRegistry",
"LossBalancerRegistry",
"Criterion",
"MSELoss",
"BCEWithLogitsLoss",
"CrossEntropyLoss",
"MultiLabelSoftMarginLoss",
"LossBalancer",
"EqualWeightBalancer",
"RandomLossWeightBalancer",
"GeometricLossBalancer",
"UncertaintyWeightBalancer",
"DynamicWeightAverageBalancer",
"GradNormBalancer",
]
196 changes: 196 additions & 0 deletions multimolecule/module/criterions/balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

import math
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from chanfig import Registry
from torch import Tensor

LossBalancerRegistry = Registry()


class LossBalancer(nn.Module):
    """Base class for loss balancers in multi-task learning.

    This class provides an interface for implementing various strategies
    to balance the losses of different tasks in a multi-task learning setup.

    Subclasses call ``super().forward(ret)`` to extract the per-task scalar
    losses from the head outputs, then combine them into a single loss.
    """

    def forward(self, ret: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Extract the per-task losses from the head outputs.

        Args:
            ret (Dict[str, Tensor]): A mapping of task names to head outputs;
                each output is a mapping containing at least a ``"loss"`` entry.

        Returns:
            Dict[str, Tensor]: A mapping of task names to their loss tensors.
        """
        # Note: the base class does NOT reduce to a single tensor; it only
        # strips the head-output wrapper so subclasses can weight the losses.
        return {k: v["loss"] for k, v in ret.items()}


@LossBalancerRegistry.register("equal", default=True)
class EqualWeightBalancer(LossBalancer):
"""Equal Weighting Balancer.
This method assigns equal weight to each task's loss, effectively averaging the losses across all tasks.
"""

def forward(self, ret: Dict[str, Tensor]) -> Tensor:
losses = super().forward(ret)
return sum(losses.values()) / len(losses)


@LossBalancerRegistry.register("random")
class RandomLossWeightBalancer(LossBalancer):
"""Random Loss Weighting Balancer.
This method assigns random weights to each task's loss, which are sampled from a softmax distribution,
as described in the paper "Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning"
by Liang et al. (https://openreview.net/forum?id=jjtFD8A1Wx).
"""

def forward(self, ret: Dict[str, Tensor]) -> Tensor:
losses = super().forward(ret)
loss = torch.stack(list(losses.values()))
weight = F.softmax(torch.randn(len(losses), device=loss.device, dtype=loss.dtype), dim=-1)
return loss.T @ weight


@LossBalancerRegistry.register("geometric")
class GeometricLossBalancer(LossBalancer):
"""Geometric Loss Strategy Balancer.
This method computes the geometric mean of the task losses, which can be useful for balancing tasks with different
scales, as described in the paper "MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for
Multi-Task Learning" by Chennupati et al. (https://arxiv.org/abs/1904.08492).
"""

def forward(self, losses: Dict[str, Tensor]) -> Tensor:
return math.prod(losses.values()) ** (1 / len(losses))


@LossBalancerRegistry.register("uncertainty")
class UncertaintyWeightBalancer(LossBalancer):
"""Uncertainty Weighting Balancer.
This method uses task uncertainty to weight the losses, as described in the paper "Multi-Task Learning Using
Uncertainty to Weigh Losses for Scene Geometry and Semantics" by Kendall et al. (https://arxiv.org/abs/1705.07115).
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.log_vars = nn.ParameterDict()

def forward(self, ret: Dict[str, Tensor]) -> Tensor:
losses = super().forward(ret)
for task in losses.keys():
if task not in self.log_vars:
self.log_vars[task] = nn.Parameter(torch.zeros(1))

weighted_losses = [
torch.exp(-self.log_vars[task]) * loss + self.log_vars[task] for task, loss in losses.items()
]
return sum(weighted_losses) / len(weighted_losses)


@LossBalancerRegistry.register("dynamic")
class DynamicWeightAverageBalancer(LossBalancer):
"""Dynamic Weight Average Balancer.
This method dynamically adjusts the weights of task losses based on their relative changes over time, as described
in the paper "End-to-End Multi-Task Learning with Attention" by Liu et al. (https://arxiv.org/abs/1803.10704).
"""

def __init__(self, *args, temperature: float = 2.0, **kwargs):
super().__init__(*args, **kwargs)
self.temperature = temperature
self.task_losses_history: List[List[float]] = []

def forward(self, ret: Dict[str, Tensor]) -> Tensor:
losses = super().forward(ret)
if len(self.task_losses_history) < 2:
self.task_losses_history.append([loss.item() for loss in losses.values()])
return sum(losses.values()) / len(losses)

curr_losses = [loss.item() for loss in losses.values()]
prev_losses = self.task_losses_history[-1]
loss_ratios = [c / (p + 1e-8) for c, p in zip(curr_losses, prev_losses)]

exp_weights = torch.exp(torch.tensor(loss_ratios) / self.temperature)
weights = len(losses) * F.softmax(exp_weights, dim=-1)

self.task_losses_history.append(curr_losses)
if len(self.task_losses_history) > 2:
self.task_losses_history.pop(0)

return sum(w * l for w, l in zip(weights, losses.values())) / len(losses)


@LossBalancerRegistry.register("gradnorm")
class GradNormBalancer(LossBalancer):
"""GradNorm Balancer.
This method balances task losses by normalizing gradients, as described in the paper "GradNorm: Gradient
Normalization for Adaptive Loss Balancing in Deep Multitask Networks" by Chen et al.
(https://arxiv.org/abs/1711.02257).
"""

def __init__(self, *args, alpha: float = 1.5, **kwargs):
super().__init__(*args, **kwargs)
self.alpha = alpha
self.task_weights = nn.ParameterDict()
self.initial_losses: Dict[str, Tensor] = {}

def forward(self, ret: Dict[str, Tensor]) -> Tensor:
losses = super().forward(ret)

for task in losses.keys():
if task not in self.task_weights:
self.task_weights[task] = nn.Parameter(torch.ones(1, device=losses[task].device))
self.initial_losses[task] = losses[task].detach()

loss_ratios = {task: loss / (self.initial_losses[task] + 1e-8) for task, loss in losses.items()}
avg_loss_ratio = sum(loss_ratios.values()) / len(loss_ratios)

relative_inverse_rates = {
task: (ratio / (avg_loss_ratio + 1e-8)) ** self.alpha for task, ratio in loss_ratios.items()
}

weighted_losses = {task: self.task_weights[task] * loss for task, loss in losses.items()}
grad_norms = {
task: torch.norm(torch.autograd.grad(weighted_loss, self.task_weights[task], retain_graph=True)[0])
for task, weighted_loss in weighted_losses.items()
}
mean_grad_norm = sum(grad_norms.values()) / len(grad_norms)

for task in losses.keys():
target_grad = mean_grad_norm * relative_inverse_rates[task]
grad_norm = grad_norms[task]
self.task_weights[task].data = torch.clamp(
self.task_weights[task] * (target_grad / (grad_norm + 1e-8)), min=0.0
)
weight_sum = sum(w.item() for w in self.task_weights.values())
scale = len(losses) / (weight_sum + 1e-8)
for task in losses.keys():
self.task_weights[task].data *= scale

return sum(self.task_weights[task] * loss for task, loss in losses.items())
11 changes: 10 additions & 1 deletion multimolecule/module/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch import Tensor, nn

from .backbones import BackboneRegistry
from .criterions.balancer import LossBalancerRegistry
from .heads import HeadRegistry
from .necks import NeckRegistry
from .registry import ModelRegistry
Expand All @@ -42,10 +43,12 @@ def __init__(
self,
backbone: dict,
heads: dict,
balancer: dict | None = None,
neck: dict | None = None,
max_length: int = 1024,
truncation: bool = False,
probing: bool = False,
config: dict | None = None,
):
super().__init__()

Expand Down Expand Up @@ -87,6 +90,8 @@ def __init__(
for param in self.backbone.parameters():
param.requires_grad = False

self.balancer = LossBalancerRegistry.build(balancer)

def forward(
self,
sequence: NestedTensor | Tensor,
Expand All @@ -99,9 +104,13 @@ def forward(
output, _ = self.backbone(sequence, discrete, continuous)
if self.neck is not None:
output = self.neck(**output)
if not labels:
return output
for task, label in labels.items():
ret[task] = self.heads[task](output, input_ids=sequence, labels=label)
return ret
if len(ret) == 1:
return ret, ret[task]["loss"]
return ret, self.balancer(ret)

def trainable_parameters(
self,
Expand Down
25 changes: 6 additions & 19 deletions multimolecule/runners/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import math
import os
from functools import cached_property, partial
from typing import Any, Tuple
Expand All @@ -28,7 +27,6 @@
from datasets import disable_progress_bars, get_dataset_split_names
from lazy_imports import try_import
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from tqdm import tqdm
from transformers import AutoTokenizer
Expand Down Expand Up @@ -66,7 +64,8 @@ def __init__(self, config: MultiMoleculeConfig):
ema_enabled = self.config.ema.pop("enabled", False)
if ema_enabled:
ema.check()
self.ema = EMA(self.model, coerce_dtype=True)
self.config.ema.setdefault("coerce_dtype", True)
self.ema = EMA(self.model, **self.config.ema)
self.config.ema.enabled = ema_enabled
if self.config.training:
optim_name = self.config.optim.pop("name", "AdamW")
Expand All @@ -79,7 +78,6 @@ def __init__(self, config: MultiMoleculeConfig):
self.config.optim.pretrained_ratio = pretrained_ratio
if self.config.sched:
self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
self.balance = self.config.get("balance", "ew")
self.train_metrics = self.build_train_metrics()
self.evaluate_metrics = self.build_evaluate_metrics()

Expand All @@ -99,16 +97,14 @@ def __post_init__(self):

def train_step(self, data) -> Tuple[Any, torch.Tensor]:
with self.autocast(), self.accumulate():
pred = self.model(**data)
loss = self.loss_fn(pred, data)
pred, loss = self.model(**data)
self.advance(loss)
self.metric_fn(pred, data)
return pred, loss

def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
model = self.ema or self.model
pred = model(**data)
loss = self.loss_fn(pred, data)
pred, loss = model(**data)
self.metric_fn(pred, data)
return pred, loss

Expand Down Expand Up @@ -143,6 +139,8 @@ def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
model = self.ema or self.model
for _, data in tqdm(enumerate(loader), total=len(loader)): # noqa: F402
pred = model(**data)
if isinstance(pred, tuple):
pred, loss = pred
for task, p in pred.items():
preds[task].extend(p["logits"].squeeze(-1).tolist())
if task in data:
Expand All @@ -162,17 +160,6 @@ def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
return next(iter(preds.values()))
return preds

def loss_fn(self, pred, data):
if self.balance == "rlw":
loss = torch.stack([p["loss"] for p in pred.values()])
weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1)
return loss.T @ weight
if self.balance == "gls":
return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred))
if self.balance != "ew":
warn(f"Unknown balance method {self.balance}, using equal weighting.")
return sum(p["loss"] for p in pred.values()) / len(pred)

def metric_fn(self, pred, data):
metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
Expand Down

0 comments on commit 877b37d

Please sign in to comment.