
Support for Label-Dependent Loss Functions (e.g., Supervised Contrastive Loss) #32

Open
penguinwang96825 opened this issue Jul 19, 2024 · 0 comments


First of all, thank you for developing GradCache and making it available for the community. It's been incredibly useful for my work.

Currently, GradCache supports loss functions that do not require label information, such as SimCLR. However, I would like to use GradCache with label-dependent loss functions like the Supervised Contrastive (SupCon) loss.

The contrastive_loss implementation shown in the README supports only unlabeled inputs. Here is the relevant snippet from the README for reference:

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from grad_cache.functional import cached, cat_input_tensor

@cached
@autocast()
def call_model(model, input):
    # Encode one sub-batch; @cached stores the representation and returns a
    # closure used later to replay the forward pass for gradient computation.
    return model(**input).pooler_output

@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
    # x: (B, d) query reps; y: (B * k, d) passage reps. Each query x[i] is
    # matched to the positive passage y[i * k], where k = y.size(0) // x.size(0).
    target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target=target)

Could you provide guidance on how to incorporate label information into the contrastive_loss function when using GradCache? Specifically, how can the current GradCache framework be adapted to support supervised loss functions such as the SupCon loss?
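
For concreteness, here is a rough sketch of what I have in mind. This is not GradCache API, just my guess at one possible approach: it assumes that cat_input_tensor concatenates any argument that is a list of tensors, so per-micro-batch label tensors could be passed alongside the cached representations. The names supcon_loss, temperature, loader, model, optimizer, and accumulation_steps are all illustrative.

import torch
import torch.nn.functional as F
from grad_cache.functional import cat_input_tensor

@cat_input_tensor
def supcon_loss(reps, labels, temperature=0.1):
    # reps: (N, d) pooled representations; labels: (N,) integer class ids.
    reps = F.normalize(reps, dim=-1)
    sim = torch.matmul(reps, reps.T) / temperature              # (N, N) logits
    self_mask = torch.eye(sim.size(0), dtype=torch.bool, device=sim.device)
    sim = sim.masked_fill(self_mask, float('-inf'))             # exclude self-pairs
    # Positives share a label with the anchor (the anchor itself excluded).
    pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)  # row-wise log-softmax
    pos_counts = pos_mask.sum(dim=1)
    valid = pos_counts > 0                                      # anchors with >= 1 positive
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(pos_log_prob[valid] / pos_counts[valid]).mean()

The accumulation loop would then mirror the README's functional example, with labels cached next to the representations (assuming, again, that cat_input_tensor concatenates the label list the same way it concatenates the representation list):

cache, closures, label_cache = [], [], []

for step, (sub_batch, sub_labels) in enumerate(loader):
    reps, closure = call_model(model, sub_batch)
    cache.append(reps)
    closures.append(closure)
    label_cache.append(sub_labels.to(reps.device))

    if (step + 1) % accumulation_steps == 0:
        loss = supcon_loss(cache, label_cache)  # both lists concatenated by the decorator
        loss.backward()
        for f, r in zip(closures, cache):
            f(r)  # replay each sub-batch to back-propagate into the encoder
        cache, closures, label_cache = [], [], []
        optimizer.step()
        optimizer.zero_grad()

Does this look like a sensible way to thread labels through, or is there a recommended pattern?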
