Experimenting with alternative siglip loss impl for better dist scaling #971

Open · wants to merge 2 commits into base: main
src/open_clip/factory.py (2 additions, 0 deletions)
@@ -440,7 +440,9 @@ def create_loss(args):
         return SigLipLoss(
             rank=args.rank,
             world_size=args.world_size,
+            dist_impl=args.loss_dist_impl,  # siglip has multiple distributed implementations to choose from
         )
+
     return ClipLoss(
         local_loss=args.local_loss,
         gather_with_grad=args.gather_with_grad,
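For a sense of how the new argument is exercised end to end, here is a minimal single-process sketch (hypothetical usage, not part of the PR; it assumes this branch of open_clip is installed, and with world_size == 1 the dist_impl choice is validated in __init__ but no collective ops run):

import torch
import torch.nn.functional as F
from open_clip.loss import SigLipLoss

# Single-process sketch: dist_impl only affects behaviour when world_size > 1,
# so this runs without initializing a process group.
loss_fn = SigLipLoss(rank=0, world_size=1, dist_impl='shift')
image_features = F.normalize(torch.randn(4, 256), dim=-1)
text_features = F.normalize(torch.randn(4, 256), dim=-1)
logit_scale = torch.tensor(10.0)   # illustrative values; training learns these
logit_bias = torch.tensor(-10.0)
loss = loss_fn(image_features, text_features, logit_scale, logit_bias)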
src/open_clip/loss.py (57 additions, 23 deletions)
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -102,8 +104,14 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
     def get_logits(self, image_features, text_features, logit_scale):
         if self.world_size > 1:
             all_image_features, all_text_features = gather_features(
-                image_features, text_features,
-                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+                image_features,
+                text_features,
+                local_loss=self.local_loss,
+                gather_with_grad=self.gather_with_grad,
+                rank=self.rank,
+                world_size=self.world_size,
+                use_horovod=self.use_horovod,
+            )
 
             if self.local_loss:
                 logits_per_image = logit_scale * image_features @ all_text_features.T
@@ -158,12 +166,11 @@ def __init__(
         self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
 
     def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
-
-        clip_loss = torch.tensor(0)
-
         if self.clip_loss_weight:
             clip_loss = super().forward(image_features, text_features, logit_scale)
             clip_loss = self.clip_loss_weight * clip_loss
+        else:
+            clip_loss = torch.tensor(0, device=logits.device)
 
         caption_loss = self.caption_loss(
             logits.permute(0, 2, 1),
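(Aside: moving the zero into the else branch also fixes a device pitfall — torch.tensor(0) allocates on CPU by default, so on GPU runs the old initial value could trigger a device mismatch when combined with the CUDA caption loss; allocating it with device=logits.device avoids that.)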
@@ -316,19 +323,17 @@ class SigLipLoss(nn.Module):
     """
     def __init__(
             self,
-            cache_labels=False,
-            rank=0,
-            world_size=1,
-            bidir=True,
-            use_horovod=False,
+            cache_labels: bool = False,
+            rank: int = 0,
+            world_size: int = 1,
+            dist_impl: Optional[str] = None,
     ):
         super().__init__()
         self.cache_labels = cache_labels
         self.rank = rank
         self.world_size = world_size
-        assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
-        self.use_horovod = use_horovod
-        self.bidir = bidir
+        self.dist_impl = dist_impl or 'bidir'  # default to bidir exchange for now, this will likely change
+        assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
 
         # cache state FIXME cache not currently used, worthwhile?
         self.prev_num_logits = 0
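All four dist_impl options compute the same total: the local block of positives and negatives, plus every other rank's text features treated as pure negatives. They differ only in how those text features travel between ranks. The 'bidir' and 'shift' paths below rely on the neighbour-exchange helpers defined earlier in loss.py; as a rough sketch (gradient-free, for illustration only — the real neighbour_exchange_with_grad wraps the exchange in an autograd Function), the core P2P step looks like:

import torch
import torch.distributed as dist

def neighbour_exchange_fwd(from_rank, to_rank, tensor, group=None):
    # Send `tensor` to to_rank while receiving the matching tensor from
    # from_rank; batching the isend/irecv pair avoids ring deadlock.
    tensor_recv = torch.zeros_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, to_rank, group=group),
        dist.P2POp(dist.irecv, tensor_recv, from_rank, group=group),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return tensor_recv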
@@ -361,10 +366,9 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
         loss = self._loss(image_features, text_features, logit_scale, logit_bias)
 
         if self.world_size > 1:
-            # exchange text features w/ neighbour world_size - 1 times
-            right_rank = (self.rank + 1) % self.world_size
-            left_rank = (self.rank - 1 + self.world_size) % self.world_size
-            if self.bidir:
+            if self.dist_impl == 'bidir':
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features_to_left = text_features
                 num_bidir, remainder = divmod(self.world_size - 1, 2)
                 for i in range(num_bidir):
@@ -374,7 +378,6 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
                         text_features_to_left,
                         text_features_to_right,
                     )
-
                     for f in text_features_recv:
                         loss += self._loss(
                             image_features,
@@ -387,21 +390,27 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
 
                 if remainder:
                     text_features_recv = neighbour_exchange_with_grad(
-                        left_rank, right_rank, text_features_to_right)
-
+                        left_rank,
+                        right_rank,
+                        text_features_to_right
+                    )
                     loss += self._loss(
                         image_features,
                         text_features_recv,
                         logit_scale,
                         logit_bias,
                         negative_only=True,
                     )
-            else:
+            elif self.dist_impl == "shift":
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features
                 for i in range(self.world_size - 1):
                     text_features_from_left = neighbour_exchange_with_grad(
-                        left_rank, right_rank, text_features_to_right)
-
+                        left_rank,
+                        right_rank,
+                        text_features_to_right,
+                    )
                     loss += self._loss(
                         image_features,
                         text_features_from_left,
@@ -410,5 +419,30 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
                         negative_only=True,
                     )
                     text_features_to_right = text_features_from_left
+            elif self.dist_impl == "reduce":
+                for i in range(self.world_size):
+                    text_from_other = torch.distributed.nn.all_reduce(
+                        text_features * (self.rank == i),
+                        torch.distributed.ReduceOp.SUM,
+                    )
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        text_from_other,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            elif self.dist_impl == "gather":
+                all_text = torch.distributed.nn.all_gather(text_features)
+                for i in range(self.world_size):
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        all_text[i],
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            else:
+                assert False
 
         return {"contrastive_loss": loss} if output_dict else loss
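The two collective-based variants lean on torch.distributed.nn, whose all_reduce and all_gather are differentiable. The 'reduce' path uses a masking trick worth spelling out: text_features * (self.rank == i) zeroes the tensor on every rank except rank i, so a SUM all_reduce reproduces rank i's text features everywhere; the float(i != self.rank) factor then drops the self-pairing term, which the local loss already covered. A process-group-free illustration of just that masking arithmetic:

import torch

# Emulate 4 ranks in one process: "rank" r holds a tensor filled with r.
world_size = 4
per_rank_text = [torch.full((2, 3), float(r)) for r in range(world_size)]

i = 2  # the rank whose text features every rank should end up with
# Each rank contributes its masked tensor; the sum reconstructs rank i's
# tensor exactly, since every other term is zero.
summed = sum(t * (r == i) for r, t in enumerate(per_rank_text))
assert torch.equal(summed, per_rank_text[i])

The 'gather' path makes the opposite trade: one differentiable all_gather materializes every rank's text features at once and then loops over them, costing more memory but only a single collective, whereas 'reduce' issues world_size all_reduces while holding at most one remote tensor at a time.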
src/open_clip_train/params.py (6 additions, 0 deletions)
@@ -464,6 +464,12 @@ def parse_args(args):
         action="store_true",
         help='Use SigLip (sigmoid) loss.'
     )
+    parser.add_argument(
+        "--loss-dist-impl",
+        default=None,
+        type=str,
+        help='A string to specify a specific distributed loss implementation.'
+    )
 
     args = parser.parse_args(args)
 
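With the flag wired through, the implementation is chosen at launch, e.g. --siglip --loss-dist-impl gather on the usual training command line; argparse maps the flag onto args.loss_dist_impl, which create_loss forwards as shown above. A quick, hypothetical parse check (assuming parse_args requires no other flags):

from open_clip_train.params import parse_args

# The new flag lands on args.loss_dist_impl; its None default lets
# SigLipLoss fall back to 'bidir'.
args = parse_args(['--siglip', '--loss-dist-impl', 'shift'])
assert args.siglip and args.loss_dist_impl == 'shift'
assert parse_args([]).loss_dist_impl is None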