Problem with the softmax over the targets #24

Open
MarcoChain opened this issue Nov 9, 2024 · 0 comments

Hi everyone,

After examining clip.py and modules.py, I noticed a few issues. Starting from the end of the pipeline, the symmetric cross-entropy loss seems partially incorrect. Specifically:

targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')

While I appreciate the introduction of "soft targets" over the original one-hot encoding, I believe the softmax should be applied inside the cross_entropy function so that it runs over the correct dimension for both the row-wise and the column-wise call. Transposing already-softmaxed targets is not the same as applying the softmax along the first dimension (dim=0): the rows of targets.T no longer sum to one, so they are not valid probability distributions. This can cause convergence issues, especially at low temperature early in training; as the temperature scale grows, the targets approach an identity matrix and targets becomes close to targets.T, which hides the problem. Additionally, I suggest clamping the temperature so the logits are never scaled by more than 100, as recommended in the original paper to avoid instability.
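
A quick way to see the issue (a standalone sketch, not code from the repo):

import torch
import torch.nn.functional as F

S = torch.randn(4, 4)  # toy similarity matrix

transposed = F.softmax(S, dim=-1).T   # what images_loss currently receives as targets
column_wise = F.softmax(S.T, dim=-1)  # softmax taken after the transpose

print(torch.allclose(transposed, column_wise))  # False: the two differ in general
print(transposed.sum(dim=-1))   # rows do not sum to 1
print(column_wise.sum(dim=-1))  # rows sum to 1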

My suggested code adjustment would be:

t = torch.clamp(self.temperature.exp(), max=100)
...
targets = (images_similarity + texts_similarity) / 2 * t
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')

...

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    targets = F.softmax(targets, dim=-1)
    ...
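
Putting the pieces together, a self-contained sketch of the full loss could look like the following. The logits line and the parametrization of the temperature as a learnable log-scale are my assumptions about the surrounding code, in the spirit of the original CLIP:

import torch
import torch.nn.functional as F

def cross_entropy(preds, targets, reduction='none'):
    # preds and targets are raw (scaled) similarity scores; both are
    # normalized here along the last dimension, so the row-wise and
    # column-wise calls each receive valid probability distributions
    log_probs = F.log_softmax(preds, dim=-1)
    soft_targets = F.softmax(targets, dim=-1)
    loss = (-soft_targets * log_probs).sum(dim=-1)
    return loss.mean() if reduction == 'mean' else loss

# inside the model's forward, with unit-norm embeddings:
t = torch.clamp(self.temperature.exp(), max=100)  # never scale logits by more than 100
logits = (text_embeddings @ image_embeddings.T) * t
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = (images_similarity + texts_similarity) / 2 * t
texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0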

To ensure that the model focuses on the directional similarity between image and text embeddings rather than their magnitudes, I would also remove the final layer norm from the ProjectionHead module. That layer adjusts the input's mean and standard deviation but does not produce unit-norm vectors, so replacing it with L2 normalization of image_embeddings and text_embeddings would better align with the original loss formulation:

image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)
image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
text_embeddings  = text_embeddings  / text_embeddings.norm(dim=1, keepdim=True)
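
Equivalently, this can be written with F.normalize:

import torch.nn.functional as F

image_embeddings = F.normalize(self.image_projection(image_features), dim=-1)
text_embeddings = F.normalize(self.text_projection(text_features), dim=-1)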

These changes should bring the code closer to the original CLIP implementation. Let me know if anything is unclear or if I've missed anything. Thanks!
