Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix top_k for multiclassf1score #2839

Open
wants to merge 20 commits into
base: master
Choose a base branch
from

Conversation

rittik9
Copy link
Contributor

@rittik9 rittik9 commented Nov 21, 2024

What does this PR do?

Fixes #1653

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2839.org.readthedocs.build/en/2839/

@eneserdo
Copy link

eneserdo commented Nov 22, 2024

Thanks for the effort. Imho, you could have handled the "refining" on directly preds tensor using something like this:

preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
preds_top1 = preds_topk[:, 0]
preds=torch.where((target.view(-1, 1) == preds_topk).sum(dim=-1).bool(), target, preds_top1)

Which is more compact way of doing the same job. (Cloning and reshaping are omitted here)

Also, these changes will break the current tests for all top_k related classes/functions e.g. for recall, accuracy, f1, so on so forth. I think it is important to re-write these tests. Additionally, maybe for the topk accuracy you can take the scikit learn's top_k_accuracy_score as a reference.

@rittik9
Copy link
Contributor Author

rittik9 commented Nov 22, 2024

Thanks for your suggestions.I've noticed some of the tests have failed. I'm working on them. I am also comparing them with other library implementations. I'll keep updating here.

@rittik9 rittik9 marked this pull request as ready for review November 23, 2024 13:26
Copy link

codecov bot commented Nov 23, 2024

Codecov Report

Attention: Patch coverage is 11.11111% with 8 lines in your changes missing coverage. Please review.

Project coverage is 41%. Comparing base (0d3494f) to head (aa8ff49).

❗ There is a different number of reports uploaded between BASE (0d3494f) and HEAD (aa8ff49). Click for more details.

HEAD has 226 uploads less than BASE
Flag BASE (0d3494f) HEAD (aa8ff49)
macOS 15 3
python3.10 45 9
cpu 70 14
torch2.0.1 10 2
torch2.0.1+cpu 15 3
Windows 10 2
python3.12 15 3
torch2.5.0 5 1
torch2.5.0+cpu 5 1
gpu 1 0
unittest 1 0
Linux 45 9
torch2.4.1+cu121 10 2
torch2.1.2+cpu 5 1
python3.11 5 1
torch2.3.1+cpu 5 1
torch2.2.2+cpu 5 1
torch2.5.0+cu124 10 2
python3.9 5 1
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #2839     +/-   ##
========================================
- Coverage      69%     41%    -28%     
========================================
  Files         346     332     -14     
  Lines       19129   18962    -167     
========================================
- Hits        13227    7736   -5491     
- Misses       5902   11226   +5324     
---- 🚨 Try these New Features:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

top_k for multiclassf1score is not working correctly
3 participants